From caa3ad5435f1cb28a7aaad1f7ef2662d200ef4e6 Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 07:36:44 -0700 Subject: [PATCH 1/7] test(knn-bench): IndexedNearestJoinTempRBenchmark Three-config benchmark validating the per-query temp Lance design from sezruby/lance-spark#2: same data, same job, three execution paths. A: Spark crossJoin + min_by_k on parquet R (brute-force baseline) B: per-query temp Lance write + kNearestJoin against the temp URI C: Lance-native R + kNearestJoin (already-Lance reference) (B - C) = pure temp-write overhead. (B vs A) = headline speedup vs the naive parquet-R approach. Tiny scale local (M5 Max, 5 repeats): A 28,231 ms B 323 ms (36 ms tw + 287 ms probe) C 267 ms B beats A 87x; (B - C) overhead = 56 ms. Cluster mode supported via BENCH_CLUSTER_MODE=true + BENCH_DATA_PATH; cluster numbers blocked on infra (OpenShift CSI / PVC reconciler) and will be added as a follow-up. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../IndexedNearestJoinTempRBenchmark.scala | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala new file mode 100644 index 000000000..9742c4bc9 --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala @@ -0,0 +1,451 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.benchmark + +import org.apache.spark.sql.{DataFrame, Row, RowFactory, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.lance.spark.knn.LanceKnnImplicits._ + +import java.nio.file.Files +import java.util.{Locale, Random} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +/** + * Validates the per-query temp-Lance design from + * [sezruby/lance-spark#2](https://github.com/sezruby/lance-spark/issues/2): when R lives in + * parquet (or any non-Lance source), the indexed `NearestByJoin` path can still apply by + * writing R to a temp Lance dataset before the probe. Three configs run on the SAME data + * in the SAME job (no cross-run noise): + * + * A: vanilla Spark crossJoin + L2 UDF + min_by_k on parquet R — the brute-force baseline + * a user would write today; matches what Spark 4.2's `RewriteNearestByJoin` lowers to. + * B: per-query temp Lance write + `kNearestJoin` against the temp URI — the per-query + * temp design under test. + * C: Lance-native R + `kNearestJoin` — same probe pipeline as B but R was pre-written to + * Lance once outside the timing loop. The "already-Lance" reference; (B - C) is the + * pure temp-write overhead per query. + * + * Three configs answer: + * 1. Does B beat A on parquet R? (B vs A speedup — is the per-query temp story faster + * than the brute-force baseline at all?) + * 2. How much overhead does the temp write add vs. the existing Lance-already path? + * (B - C, on the same hardware in the same run, no cross-run noise) + * 3. Sanity: probe-only median of B (excluding temp write) should match C closely; if + * not, something's off in the pipeline. + * + * == Local run == + * {{{ + * MAVEN_OPTS="-Xmx12g " \ + * ./mvnw -pl lance-spark-knn_2.12 -q exec:java -Pbenchmark \ + * -Dexec.mainClass="org.lance.spark.knn.benchmark.IndexedNearestJoinTempRBenchmark" + * }}} + * + * Local laptop runs are noisy enough that single-run point estimates are unreliable + * (multi-tenant CPU contention on macOS, single-machine no-real-parallelism). Headline + * numbers should come from a real distributed cluster — see `BENCHMARK_RESULTS.md` + * § "Variance / multi-tenant noise" for the discussion in the existing benchmark. + * + * == Cluster run == + * Build the fat JAR (`./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests`), + * upload, submit with `BENCH_CLUSTER_MODE=true` and `BENCH_DATA_PATH=`. + * + * Environment: + * BENCH_SCALES — comma-separated subset of {tiny, small, medium, fat}; + * default "tiny,small". + * tiny = |R|=100K, |L|=100, dim=128 + * small = |R|=1M, |L|=1000, dim=128 + * medium = |R|=100K, |L|=100, dim=1024 + * fat = |R|=1M, |L|=1000, dim=1024 + * BENCH_REPEATS — measured iterations per config; default 3. 1 warmup. + * BENCH_CLUSTER_MODE — `true` to skip `.master()` and bind-address configs; + * must be set when submitting to a real Spark cluster. + * BENCH_DATA_PATH — shared scratch URI (file://, s3://, hdfs://, etc.); + * required in cluster mode, optional locally. + * + * Reports timing breakdown for B: (temp_write_ms + probe_ms) so the temp-write cost is + * visible relative to the probe itself. + */ +object IndexedNearestJoinTempRBenchmark { + + private val K: Int = 10 + private val Seed: Long = 1337L + + private case class Scale(name: String, numR: Int, numL: Int, dim: Int) { + def vectorBytesR: Long = numR.toLong * dim.toLong * 4L + override def toString: String = + f"$name (|R|=$numR%,d, |L|=$numL%,d, dim=$dim, R-vec=${vectorBytesR / (1024.0 * 1024.0)}%.0f MB)" + } + + private val Scales: Map[String, Scale] = Seq( + Scale("tiny", numR = 100000, numL = 100, dim = 128), + Scale("small", numR = 1000000, numL = 1000, dim = 128), + Scale("medium", numR = 100000, numL = 100, dim = 1024), + Scale("fat", numR = 1000000, numL = 1000, dim = 1024)).map(s => s.name -> s).toMap + + private case class Result( + scale: String, + config: String, + tempWriteMs: Option[Long], + totalMs: Long, + runs: Seq[Long]) { + def probeMs: Option[Long] = tempWriteMs.map(w => totalMs - w) + } + + def main(args: Array[String]): Unit = { + val scaleNames = sys.env + .getOrElse("BENCH_SCALES", "tiny,small") + .toLowerCase(Locale.ROOT) + .split(",") + .map(_.trim) + .filter(_.nonEmpty) + val scales = scaleNames.map { n => + Scales.getOrElse( + n, + sys.error(s"unknown scale '$n'; valid: ${Scales.keys.toSeq.sorted.mkString(", ")}")) + } + val repeats = sys.env.get("BENCH_REPEATS").map(_.toInt).getOrElse(3) + val clusterMode = sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true")) + val dataRootOpt = sys.env.get("BENCH_DATA_PATH").orElse(sys.env.get("BENCH_DATA")) + + // Use user-supplied path in cluster mode, otherwise a local temp dir. + // Avoid "lance" in the path token; the V2 catalog's path-identifier parser tokenises + // around it on writes. Same workaround as LanceWriteBenchmark. + val dataRoot = + dataRootOpt.getOrElse(Files.createTempDirectory("knn-tempr-bench-").toString) + + val builder = SparkSession + .builder() + .appName("indexed-nearest-temp-r-benchmark") + .config("spark.sql.crossJoin.enabled", "true") + if (!clusterMode) { + builder + .master("local[*]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .config("spark.sql.shuffle.partitions", "16") + } + val spark = builder.getOrCreate() + spark.sparkContext.setLogLevel("WARN") + + println("=" * 76) + println("Per-query temp Lance benchmark — non-Lance R via temp write") + println("=" * 76) + val masterDesc = if (clusterMode) "cluster (BENCH_CLUSTER_MODE=true)" else "local[*]" + println(f"Spark master: $masterDesc (cores=${spark.sparkContext.defaultParallelism})") + println(f"Repeats: $repeats (median reported); 1 warmup") + println(f"Data root: $dataRoot") + println(f"K: $K") + println(f"Scales: ${scales.map(_.name).mkString(", ")}") + println() + + val results = scala.collection.mutable.ArrayBuffer.empty[Result] + try { + scales.foreach { scale => + println("-" * 76) + println(s"Scale: $scale") + println("-" * 76) + + // Build left in-memory and right as parquet on disk. This is the realistic shape: + // R isn't in Lance yet; we want to time the cost of materializing it. + val leftDf = buildLeft(spark, scale).cache() + leftDf.count() + val rightParquetUri = s"$dataRoot/${scale.name}_right.parquet" + writeRightParquet(spark, scale, rightParquetUri) + val rightDfParquet = spark.read.parquet(rightParquetUri) + + // Pre-materialize a Lance-native R once (outside the timing loop) for config C. + // This is the apples-to-apples reference: same data, already-Lance — what users get + // when they store R in Lance natively. The kNearestJoin call against this URI uses + // exactly the same probe pipeline as B; the only difference is no temp write. + val rightLanceUri = s"$dataRoot/${scale.name}_right_native.lance" + writeTempLance(rightDfParquet, rightLanceUri) + val rightDfLance = spark.read.format("lance").load(rightLanceUri) + + // Sanity-check on a 16-row subset that the per-query-temp path agrees with the + // crossJoin baseline. Bail early if not. + verifyOracle(spark, scale, leftDf, rightParquetUri, dataRoot) + + // ---- Config A: vanilla Spark crossJoin + L2 UDF + min_by_k baseline ---- + // Use fewer repeats at higher scales since each A run is O(|L| × |R|) pair + // evaluations and quickly becomes minutes per run. + val baselineRepeats = if (scale.numL.toLong * scale.numR > 100000000L) 1 else repeats + val resultA = + timeIt(scale.name, "A: Spark crossJoin + min_by_k (parquet R)", baselineRepeats) { + () => crossJoinMinByK(leftDf, rightDfParquet, K) + } + results += resultA + + // ---- Config B: per-query temp Lance write + existing kNearestJoin ---- + val resultB = timeWithBreakdown(scale.name, "B: temp Lance write + kNearestJoin", repeats) { + () => + val tempUri = s"$dataRoot/${scale.name}_temp_${System.nanoTime()}" + // Step 1: temp write (timed) + val twStart = System.nanoTime() + writeTempLance(rightDfParquet, tempUri) + val twMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - twStart) + // Step 2: KNN against temp Lance (timed via runFull at outer level) + val tempLanceDf = spark.read.format("lance").load(tempUri) + val joined = leftDf.kNearestJoin( + right = tempLanceDf, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + (twMs, joined) + } + results += resultB + + // ---- Config C: Lance-native R (already-Lance reference; no temp write) ---- + val resultC = timeIt(scale.name, "C: Lance-native R + kNearestJoin", repeats) { () => + leftDf.kNearestJoin( + right = rightDfLance, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + } + results += resultC + + leftDf.unpersist() + println() + } + + println("=" * 76) + println("Summary") + println("=" * 76) + printSummary(results.toSeq) + } finally { + spark.stop() + } + } + + // -- workload setup ---------------------------------------------------------------------- + + private def buildLeft(spark: SparkSession, scale: Scale): DataFrame = { + val schema = leftSchema(scale.dim) + val rng = new Random(Seed ^ 1L) + val rows = (0 until scale.numL).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(rng, scale.dim)) + } + spark.createDataFrame(rows.asJava, schema) + } + + private def writeRightParquet(spark: SparkSession, scale: Scale, uri: String): Unit = { + val schema = rightSchema(scale.dim) + val parts = math.max(spark.sparkContext.defaultParallelism, 8) + val numR = scale.numR + val dim = scale.dim + val rdd = spark.sparkContext + .range(0L, numR.toLong, 1L, parts) + .mapPartitionsWithIndex { (idx, iter) => + val rng = new Random(0xCAFEBABEL ^ idx.toLong) + iter.map { i => + val v = new Array[Float](dim) + var k = 0 + while (k < dim) { v(k) = rng.nextFloat(); k += 1 } + RowFactory.create(Integer.valueOf(i.toInt), v): Row + } + } + spark.createDataFrame(rdd, schema).write.parquet(uri) + } + + private def writeTempLance(rightDf: DataFrame, tempUri: String): Unit = { + // Per-query temp: project rid + rvec only (the columns the existing kNearestJoin path + // needs). Carrying additional payload columns is a follow-up; this benchmark validates + // the minimal shape. + rightDf.select("rid", "rvec").write.format("lance").save(tempUri) + } + + private def leftSchema(dim: Int): StructType = new StructType( + Array( + StructField("lid", IntegerType, nullable = false), + StructField( + "lvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + + private def rightSchema(dim: Int): StructType = new StructType( + Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } + + // -- baseline ---------------------------------------------------------------------------- + + private def l2Udf = + udf((a: Seq[Float], b: Seq[Float]) => { + var s = 0.0f + var i = 0 + while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 } + s + }) + + /** + * Same shape as `IndexedNearestJoinBenchmark.crossProductMinByK`. The realistic + * baseline a user would write today on parquet R (and what Spark 4.2's + * RewriteNearestByJoin lowers to). + */ + private def crossJoinMinByK(left: DataFrame, right: DataFrame, k: Int): DataFrame = { + val l2 = l2Udf + val r = right.select("rid", "rvec") + val crossed = left.crossJoin(r).withColumn("__dist", l2(col("lvec"), col("rvec"))) + crossed.groupBy("lid") + .agg( + slice( + sort_array(collect_list(struct(col("__dist"), col("rid"))), asc = true), + 1, + k).as("__matches")) + .select(col("lid"), inline(col("__matches")).as(Seq("__dist", "rid"))) + .select("lid", "rid", "__dist") + } + + // -- timing harness ---------------------------------------------------------------------- + + private def runFull(df: DataFrame): Unit = + df.write.format("noop").mode("overwrite").save() + + private def timeIt(scale: String, config: String, repeats: Int)(f: () => DataFrame): Result = { + print(s" $config ... ") + System.out.flush() + runFull(f()) // warmup + val runs = (0 until repeats).map { _ => + val t0 = System.nanoTime() + runFull(f()) + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) + } + val sortedRuns = runs.sorted + val median = sortedRuns(sortedRuns.length / 2) + println(s"runs=${runs.mkString("[", ",", "]")} ms, median=$median ms") + Result(scale, config, tempWriteMs = None, totalMs = median, runs = runs) + } + + /** + * Like timeIt but the runnable returns (temp_write_ms, df). The full timing includes + * the temp write; we report it separately so the probe-only cost is visible. + */ + private def timeWithBreakdown(scale: String, config: String, repeats: Int)( + f: () => (Long, DataFrame)): Result = { + print(s" $config ... ") + System.out.flush() + val (_, warmDf) = f() + runFull(warmDf) // warmup + val records = (0 until repeats).map { _ => + val t0 = System.nanoTime() + val (twMs, df) = f() + runFull(df) + val total = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) + (twMs, total) + } + val sortedTotals = records.map(_._2).sorted + val medianTotal = sortedTotals(sortedTotals.length / 2) + val medianTw = records.map(_._1).sorted.apply(records.length / 2) + println( + f"runs(total)=${records.map(_._2).mkString("[", ",", "]")} ms, median=$medianTotal%d ms " + + f"(temp write median=$medianTw%d ms, probe median=${medianTotal - medianTw}%d ms)") + Result( + scale, + config, + tempWriteMs = Some(medianTw), + totalMs = medianTotal, + runs = records.map(_._2)) + } + + // -- oracle equivalence ------------------------------------------------------------------ + + private def verifyOracle( + spark: SparkSession, + scale: Scale, + leftDf: DataFrame, + rightParquetUri: String, + dataRoot: String): Unit = { + println(" Sanity: oracle check on 16-row left subset ...") + val left16 = leftDf.limit(16).cache() + left16.count() + val rightDfParquet = spark.read.parquet(rightParquetUri) + val tempUri = s"$dataRoot/${scale.name}_oracle_temp_${System.nanoTime()}" + writeTempLance(rightDfParquet, tempUri) + val tempLance = spark.read.format("lance").load(tempUri) + + val viaTemp = left16.kNearestJoin( + right = tempLance, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + val viaBaseline = crossJoinMinByK(left16, rightDfParquet, K) + + val tempByLid = viaTemp.collect().groupBy(_.getAs[Int]("lid")).map { + case (lid, rows) => lid -> rows.map(_.getAs[Int]("rid")).toSet + } + val baseByLid = viaBaseline.collect().groupBy(_.getAs[Int]("lid")).map { + case (lid, rows) => lid -> rows.map(_.getAs[Int]("rid")).toSet + } + val mismatches = tempByLid.toSeq.flatMap { + case (lid, ids) => + baseByLid.get(lid) match { + case Some(b) if b == ids => None + case Some(b) => Some(s"lid=$lid: temp=$ids baseline=$b") + case None => Some(s"lid=$lid: missing from baseline") + } + } + if (mismatches.nonEmpty) { + sys.error(s"Oracle mismatch:\n ${mismatches.mkString("\n ")}") + } + left16.unpersist() + println(f" ... oracle equivalence holds (${tempByLid.size} left rows × K=$K).") + } + + // -- reporting --------------------------------------------------------------------------- + + private def printSummary(results: Seq[Result]): Unit = { + val byScale = results.groupBy(_.scale) + println( + f"${"scale"}%-8s ${"config"}%-46s ${"med ms"}%8s ${"tw ms"}%8s ${"probe ms"}%9s ${"vs A"}%6s") + println("-" * 95) + val scaleOrder = Seq("tiny", "small", "medium", "fat").filter(byScale.contains) + scaleOrder.foreach { sc => + val rs = byScale(sc) + val baseline = rs.find(_.config.startsWith("A:")).map(_.totalMs).getOrElse(0L) + rs.foreach { r => + val twStr = r.tempWriteMs.map(_.toString).getOrElse("—") + val probeStr = r.probeMs.map(_.toString).getOrElse("—") + val speedup = + if (baseline > 0 && r.totalMs > 0) f"${baseline.toDouble / r.totalMs}%.1fx" else "—" + println( + f"${r.scale}%-8s ${r.config}%-46s ${r.totalMs}%8d $twStr%8s $probeStr%9s $speedup%6s") + } + println() + } + } +} From b9043658f914426fefbc6a6ce7f601e8479a580d Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 08:15:56 -0700 Subject: [PATCH 2/7] =?UTF-8?q?feat(knn):=20LanceTempR=20helper=20?= =?UTF-8?q?=E2=80=94=20eager=20temp-Lance=20materialization=20for=20non-La?= =?UTF-8?q?nce=20R?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per sezruby/lance-spark#2 stage 1: helper for materializing an arbitrary DataFrame to a temp Lance dataset before the existing indexed NearestByJoin pipeline. R can be parquet, delta, in-memory, or the result of an arbitrary upstream Spark plan; the helper writes it once and returns a URI the existing LanceProbeStage / LanceMaterializeStage consume unchanged. - LanceTempR.materialize(right, vecCol, projection, scratchDir): String Synthesises a unique _rid via monotonically_increasing_id(), projects rid + vec + caller-requested payload cols, writes to a unique sub-path under scratchDir. - LanceTempR.resolveScratchDir(spark): String Reads spark.lance.knn.tempR.dir; in cluster mode (master != local*), requires it to be set so the temp lands on a path every executor can see (s3://..., hdfs://..., file:///shared-mount/...). Local mode falls back to spark.local.dir + /lance-temp-r. Validation: - Round-trip: row count + rid uniqueness + vector column equality - Projection: temp schema is exactly rid + vec + requested cols - Subplan-backed sources (Filter+Project chain over parquet): same shape - Empty source: empty Lance dataset, no error - Validation: missing vec, unknown projection, reserved rid name → fail fast - resolveScratchDir: conf-key honoured; local-mode fallback writes correctly 71/71 tests pass (60 existing + 11 new). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../lance/spark/knn/internal/LanceTempR.scala | 190 ++++++++++ .../spark/knn/internal/LanceTempRTest.scala | 330 ++++++++++++++++++ 2 files changed, 520 insertions(+) create mode 100644 lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala create mode 100644 lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala new file mode 100644 index 000000000..4065f8465 --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.functions.{col, monotonically_increasing_id} + +import java.nio.file.{Files, Path, Paths} +import java.util.UUID + +/** + * Per-query temp Lance materialization for the right side of an indexed + * `NearestByJoin` when R is not already a Lance scan (parquet, delta, in-memory, + * arbitrary subplan). Materialization is eager: when called, the helper drives + * `right.write.format("lance").save(tempUri)` and returns the URI. The caller + * then passes that URI to the existing Lance-native probe pipeline; the rest of + * the staged plan (probe / merge / materialize) is unchanged. + * + * Design: see [sezruby/lance-spark#2](https://github.com/sezruby/lance-spark/issues/2). + * + * == Why a synchronous helper rather than a Catalyst exec node == + * + * The probe pipeline reads R via `LanceProbeStage.Conf.datasetUri`, which is captured + * once at plan construction time. To plumb a temp-write through Catalyst as a separate + * node we'd have to restructure `LanceProbeExec` (currently `UnaryExecNode` with the + * left side as its child) into a multi-input shape that depends on both the left plan + * AND a sibling temp-write — substantial blast radius. The simpler form: do the temp + * write at plan-build time, hand the resulting URI to the probe like any other Lance + * URI. Same data path on the wire; the only loss is `df.explain()` doesn't display the + * temp write as its own Catalyst operator. A future PR can promote it to an exec node + * if `df.explain()` visibility becomes load-bearing for users. + * + * == Why `monotonically_increasing_id` is the right rid == + * + * Per-query temp doesn't need cross-execution rid stability — temp is built and consumed + * in the same job, then deleted. `monotonically_increasing_id` is unique within a single + * execution and zero-cost to compute, so it fits exactly. For a future cached / persistent + * sidecar (different feature, not this issue), `_metadata.row_index` would be the natural + * choice for parquet-backed R. + */ +private[knn] object LanceTempR { + + /** + * Column name for the synthetic rid that the probe stage's `_rowid IN (...)` lookup + * targets after the write. Stable identifier — referenced by callers when constructing + * `rightProjection`. + */ + val RidColumnName: String = "_rid" + + /** + * User-tunable Spark conf key pointing at the directory under which per-query temp + * Lance datasets are created. In cluster mode this MUST be set to a path every + * executor (and the driver) can read+write — typically a shared object-store URI + * (`s3://...`, `abfss://...`, `file:///shared-mount/...`) or HDFS. In local mode the + * helper falls back to a subdirectory of `spark.local.dir` when this key is unset. + */ + val ScratchDirConfKey: String = "spark.lance.knn.tempR.dir" + + /** + * Materialize `right` to a temp Lance dataset suitable for use as the right side of + * the existing indexed `NearestByJoin` pipeline. + * + * Steps: + * 1. Compute the projection schema: rid (synthesised) + vec + any caller-requested + * additional columns the parent plan references. + * 2. Build a projected DataFrame: `right.select(monotonically_increasing_id() as _rid, + * ...projection)`. + * 3. Write it to `/` as a Lance dataset. + * 4. Return the URI string. + * + * The caller is responsible for: + * - Deleting the temp directory when the query finishes (see `LanceTempLifecycle`). + * - Telling the probe pipeline to project / materialize `RidColumnName` and `vecCol` + * (and any payload columns it should carry). + * + * @param right The non-Lance DataFrame to materialize. + * @param vecCol Name of the FixedSizeList vector column on `right`. + * @param projection Columns from `right` to carry into temp Lance, in addition to the + * synthesised rid and the vector. Empty Seq = carry rid + vec only. + * Use this to thread any payload columns the parent plan references. + * @param scratchDir Directory under which to create the temp Lance dataset. Must + * be a path the executor processes can write to (local FS for + * single-node, shared object store for cluster mode). Use + * [[resolveScratchDir]] to pick this up from session config in + * typical callers. + * @return The URI of the materialized temp Lance dataset. + */ + def materialize( + right: DataFrame, + vecCol: String, + projection: Seq[String], + scratchDir: String): String = { + require(vecCol.nonEmpty, "vecCol must not be empty") + require(scratchDir.nonEmpty, "scratchDir must not be empty") + require( + right.schema.fieldNames.contains(vecCol), + s"right DataFrame schema does not contain vector column '$vecCol'; " + + s"have [${right.schema.fieldNames.mkString(", ")}]") + val unknownCols = projection.filterNot(right.schema.fieldNames.contains) + require( + unknownCols.isEmpty, + s"projection columns not present in right DataFrame schema: " + + s"[${unknownCols.mkString(", ")}]; have [${right.schema.fieldNames.mkString(", ")}]") + require( + !projection.contains(RidColumnName), + s"projection must not include the reserved rid column name '$RidColumnName' — " + + "the helper synthesises it. Pick a different name on `right` or rename before calling.") + + val tempUri = mintTempUri(scratchDir) + val ridCol: Column = monotonically_increasing_id().as(RidColumnName) + val payloadCols: Seq[Column] = (vecCol +: projection.filterNot(_ == vecCol)).distinct.map(col) + val projected: DataFrame = right.select(ridCol +: payloadCols: _*) + + projected.write.format("lance").save(tempUri) + tempUri + } + + /** + * Resolve a writable scratch directory from session configuration, with a clear error + * for cluster runs that haven't set [[ScratchDirConfKey]]. + * + * Resolution order: + * 1. `spark.lance.knn.tempR.dir` if set — used as-is. Caller is responsible for it + * being executor-readable. + * 2. Local-mode-only fallback: `spark.local.dir` first entry + `/lance-temp-r`. Only + * acceptable in `local[*]` mode where the driver and executors share a JVM and + * hence the local FS. + * + * In a cluster (`spark.master` does not start with `local`), missing + * [[ScratchDirConfKey]] throws [[IllegalStateException]] — failing here is much better + * than failing later with a `FileNotFoundException` on an executor that can't see the + * driver's local disk. + */ + def resolveScratchDir(spark: SparkSession): String = { + spark.conf.getOption(ScratchDirConfKey).filter(_.nonEmpty) match { + case Some(p) => p + case None => + val master = Option(spark.sparkContext.master).getOrElse("") + if (!master.startsWith("local")) { + throw new IllegalStateException( + s"$ScratchDirConfKey is not set and Spark master is '$master' — per-query " + + "temp Lance materialization needs a scratch path every executor can " + + s"read+write. Set $ScratchDirConfKey to a shared URI (s3://..., abfss://..., " + + "file:///shared-mount/..., hdfs://...).") + } + val localDir = Option(spark.sparkContext.getConf.get("spark.local.dir", null)) + .map(_.split(",").head.trim) + .filter(_.nonEmpty) + .getOrElse(System.getProperty("java.io.tmpdir")) + s"$localDir/lance-temp-r" + } + } + + /** + * Generate a unique scratch path under `scratchDir`. Caller must clean up. The path + * deliberately does NOT include the substring "lance" — the V2 catalog's path-identifier + * parser tokenises around it on writes (same workaround documented in + * `LanceWriteBenchmark`). + */ + def mintTempUri(scratchDir: String): String = { + val token = "tempr-" + UUID.randomUUID().toString + // For a local scratch dir, ensure the parent exists. Object-store URIs (s3://, etc.) + // are passed through as-is — the lance writer creates the path. + if (isLocalPath(scratchDir)) { + val parent: Path = Paths.get(stripFileScheme(scratchDir)) + Files.createDirectories(parent) + parent.resolve(token).toString + } else { + val sep = if (scratchDir.endsWith("/")) "" else "/" + s"$scratchDir$sep$token" + } + } + + private def isLocalPath(uri: String): Boolean = + uri.startsWith("/") || uri.startsWith("file://") || !uri.contains("://") + + private def stripFileScheme(uri: String): String = + if (uri.startsWith("file://")) uri.substring("file://".length) else uri +} diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala new file mode 100644 index 000000000..6512d4f98 --- /dev/null +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala @@ -0,0 +1,330 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Row, RowFactory, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.io.TempDir + +import java.nio.file.{Files, Path} +import java.util.Random + +import scala.collection.JavaConverters._ + +/** + * Behavioural tests for the per-query temp-Lance helper [[LanceTempR]]. Validates the + * properties the design relies on: + * + * - Round-trip: rows materialized to temp Lance read back with the same row count and + * payload as the source DataFrame. + * - Synthesised rid is unique within a single materialization. + * - FixedSizeList vector columns survive the write+read. + * - Caller-requested projection columns are present in the temp; non-requested are + * dropped (column pruning). + * - Subplan-backed sources (Filter / Project chains over parquet) work the same as + * a flat parquet read — the helper only sees a `DataFrame`, not its source. + * - Empty source produces an empty (but readable) temp. + * - Validation: missing vec col, unknown projection col, reserved rid name → fail fast. + */ +class LanceTempRTest { + + @TempDir var tempDir: Path = _ + private var spark: SparkSession = _ + + private val Dim: Int = 8 + private val NumRows: Int = 32 + private val Seed: Long = 11L + + @BeforeEach def setup(): Unit = { + spark = SparkSession.builder() + .appName("lance-temp-r-test") + .master("local[2]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .getOrCreate() + } + + @AfterEach def teardown(): Unit = { + if (spark != null) spark.stop() + } + + // -- core round-trip ------------------------------------------------------------------------ + + /** Parquet-backed source: write via the helper, read back, verify row count + uniqueness of rid. */ + @Test def testParquetRoundTripRowCountAndRidUnique(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val readBack = spark.read.format("lance").load(tempUri) + val rows = readBack.collect() + assertEquals(NumRows, rows.length, "row count round-trips") + + val rids = readBack.select(col(LanceTempR.RidColumnName)).collect().map(_.getLong(0)) + assertEquals(NumRows, rids.distinct.length, "rid column is unique within one materialization") + } + + /** Vec column survives unchanged: per-row equality with the source. */ + @Test def testVectorColumnPreserved(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id"), + scratchDir = scratch()) + + val srcRows = rightDf.orderBy("id").collect() + val tempRows = spark.read.format("lance").load(tempUri).orderBy("id").collect() + assertEquals(srcRows.length, tempRows.length) + srcRows.zip(tempRows).foreach { case (s, t) => + val sv = s.getAs[Seq[Float]]("vec").toArray + val tv = t.getAs[Seq[Float]]("vec").toArray + assertArrayEquals(sv, tv, 1e-6f, s"vector mismatch for id=${s.getInt(s.fieldIndex("id"))}") + } + } + + // -- projection (column pruning at write time) --------------------------------------------- + + /** When projection is specified, only requested cols (plus rid + vec) appear in temp. */ + @Test def testProjectionTrimsExtraColumns(): Unit = { + val parquetUri = writeWideParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + // Source has id, vec, label, payload, untouched. Project only id + label. + val tempUri = LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id", "label"), + scratchDir = scratch()) + + val tempSchema = spark.read.format("lance").load(tempUri).schema.fieldNames.toSet + assertEquals( + Set(LanceTempR.RidColumnName, "vec", "id", "label"), + tempSchema, + "temp Lance schema should be exactly rid + vec + projection cols") + } + + /** Empty projection still gets rid + vec. */ + @Test def testEmptyProjectionGivesRidPlusVec(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val tempSchema = spark.read.format("lance").load(tempUri).schema.fieldNames.toSet + assertEquals(Set(LanceTempR.RidColumnName, "vec"), tempSchema) + } + + // -- subplan source (the load-bearing case) ------------------------------------------------- + + /** + * Source is a subplan: parquet → Filter → Project. Helper handles it the same as flat parquet, + * because it only consumes a DataFrame, not knowledge of the source. + */ + @Test def testSubplanSourceFilterPlusProject(): Unit = { + val parquetUri = writeWideParquet(NumRows, Dim) + val raw = spark.read.parquet(parquetUri) + // Keep only label = "even" rows (ids 0,2,4,...) and drop the payload column. + val subplanRight = raw.filter(col("label") === "even").select("id", "vec") + + val expectedCount = subplanRight.count() + assertTrue(expectedCount > 0, "test setup: subplan should have rows") + + val tempUri = LanceTempR.materialize( + subplanRight, + vecCol = "vec", + projection = Seq("id"), + scratchDir = scratch()) + val readBack = spark.read.format("lance").load(tempUri) + assertEquals( + expectedCount, + readBack.count(), + "row count of temp Lance equals row count of subplan-evaluated source") + } + + // -- empty input ---------------------------------------------------------------------------- + + /** Empty source DataFrame: temp dataset is created and reads back as 0 rows. */ + @Test def testEmptyDataFrame(): Unit = { + val parquetUri = writeRandomParquet(0, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val readBack = spark.read.format("lance").load(tempUri) + assertEquals(0L, readBack.count(), "empty source produces empty Lance dataset") + } + + // -- validation ----------------------------------------------------------------------------- + + @Test def testRejectsMissingVecColumn(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "no_such_col", + projection = Seq.empty, + scratchDir = scratch())) + assertTrue(ex.getMessage.contains("no_such_col")) + } + + @Test def testRejectsUnknownProjectionColumn(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id", "ghost"), + scratchDir = scratch())) + assertTrue(ex.getMessage.contains("ghost")) + } + + @Test def testRejectsReservedRidName(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq(LanceTempR.RidColumnName), + scratchDir = scratch())) + assertTrue(ex.getMessage.contains(LanceTempR.RidColumnName)) + } + + // -- resolveScratchDir --------------------------------------------------------------------- + + /** Conf key set: returned as-is. */ + @Test def testResolveScratchDirHonoursConfKey(): Unit = { + val explicit = scratch() + spark.conf.set(LanceTempR.ScratchDirConfKey, explicit) + try { + assertEquals(explicit, LanceTempR.resolveScratchDir(spark)) + } finally spark.conf.unset(LanceTempR.ScratchDirConfKey) + } + + /** + * Conf key unset in local mode: falls back to a local-FS path under spark.local.dir + * (or system tmp). The exact path is implementation-detail; the contract is "non-empty + * and works locally", verified by an actual round-trip below. + */ + @Test def testResolveScratchDirLocalFallbackWorks(): Unit = { + spark.conf.unset(LanceTempR.ScratchDirConfKey) + val resolved = LanceTempR.resolveScratchDir(spark) + assertTrue(resolved.nonEmpty, "local fallback must produce a non-empty path") + // Confirm it actually works as a scratchDir argument: round-trip a tiny dataset. + val parquetUri = writeRandomParquet(4, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = resolved) + assertEquals(4L, spark.read.format("lance").load(tempUri).count()) + } + + // -- helpers -------------------------------------------------------------------------------- + + /** Build (id, vec) parquet under tempDir; returns its URI. */ + private def writeRandomParquet(n: Int, dim: Int): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + val rng = new Random(Seed) + val rows: Seq[Row] = (0 until n).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(rng, dim)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val uri = subPath("right_parquet").toString + df.write.parquet(uri) + uri + } + + /** Wider source with label + payload + untouched, for projection-trim and subplan tests. */ + private def writeWideParquet(n: Int, dim: Int): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()), + StructField("label", StringType, nullable = false), + StructField("payload", StringType, nullable = false), + StructField("untouched", IntegerType, nullable = false))) + val rng = new Random(Seed) + val rows: Seq[Row] = (0 until n).map { i => + RowFactory.create( + Integer.valueOf(i), + randomVector(rng, dim), + if (i % 2 == 0) "even" else "odd", + s"p$i", + Integer.valueOf(i * 17)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val uri = subPath("wide_parquet").toString + df.write.parquet(uri) + uri + } + + /** Scratch directory exists (the helper writes a child of it) — created if needed. */ + private def scratch(): String = { + val p = tempDir.resolve("scratch_" + System.nanoTime()) + Files.createDirectories(p) + p.toString + } + + /** + * Path that does NOT pre-exist — used as a Spark write target. Spark refuses to write + * to an existing path without overwrite mode. + */ + private def subPath(name: String): Path = + tempDir.resolve(name + "_" + System.nanoTime()) + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } +} From 91e25ef25b41ae3072bc60793b59ddaaf28d9b5d Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 08:20:09 -0700 Subject: [PATCH 3/7] feat(knn): kNearestJoin extension transparently materializes non-Lance R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per sezruby/lance-spark#2 stage 2: the df.kNearestJoin(rightDf, ...) extension now accepts any DataFrame on the right side, not only Lance scans. When the right side is not a Lance scan, the extension materializes it once via LanceTempR.materialize() and routes the existing probe pipeline against the temp URI. Same data path on the wire. Behavior change: Before: parquet / in-memory / subplan R → IllegalArgumentException After: same inputs → temp Lance materialization → indexed kNN works Lance scans still take the existing fast path (no temp write). extractLanceUri now returns Option[(String, Option[Long])] instead of throwing on miss; callers fall through to materializeNonLanceR which: - Calls LanceTempR.resolveScratchDir to find a writable scratch dir (spark.lance.knn.tempR.dir is required in cluster mode; local mode falls back to spark.local.dir) - Materializes via LanceTempR.materialize with the user-specified rightProjection (or all of R's non-vector columns if rightProjection is None) Tests: replaced the three "throws on non-Lance R" cases with three positive oracle-equivalence tests covering parquet R, subplan-backed R (parquet → Filter → Project), and in-memory + alias-wrapped R. Lance-scan happy path and Filter-on-Lance unchanged. 71/71 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../lance/spark/knn/LanceKnnImplicits.scala | 113 +++++++--- .../spark/knn/LanceKnnImplicitsTest.scala | 208 ++++++++++++------ 2 files changed, 223 insertions(+), 98 deletions(-) diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala index 5b349dd85..e8862cb46 100644 --- a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala @@ -16,6 +16,7 @@ package org.lance.spark.knn import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.lance.spark.knn.internal.LanceTempR /** * Idiomatic DataFrame extension for the indexed nearest-K join. The Phase 2 SQL syntax @@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation * metric = "l2") * }}} * - * The right DataFrame MUST be a Lance scan — `spark.read.format("lance").load(uri)`. The - * extension extracts the underlying URI from the right-side analyzed plan; if it can't find a - * `LanceTable` it throws `IllegalArgumentException`. This is intentional: the indexed path - * uses Lance's Java API directly to open the dataset, so a non-Lance DataFrame cannot be - * substituted (there's no general "any DataFrame" indexed path). + * The right DataFrame can be either: + * + * - A Lance scan (`spark.read.format("lance").load(uri)`). The extension extracts the + * underlying URI from the right-side analyzed plan and the existing probe pipeline runs + * against it directly. + * - Any other DataFrame (parquet, delta, in-memory, the result of an arbitrary upstream + * plan). The extension materializes it to a temp Lance dataset first via + * [[org.lance.spark.knn.internal.LanceTempR.materialize]], then runs the same probe + * pipeline against the temp URI. See sezruby/lance-spark#2 for the design. + * + * For the temp-Lance path to work, the session must either be in local mode or have + * `spark.lance.knn.tempR.dir` set to a path every executor (and the driver) can read+write + * — typically a shared object store (`s3://...`, `abfss://...`, `file:///shared-mount/...`) + * or HDFS. Cluster runs missing this conf fail fast at materialization time with a clear + * error message. * * == Why an extension method, not a builder == * @@ -55,17 +66,23 @@ object LanceKnnImplicits { implicit class LanceKnnDataFrameOps(val df: DataFrame) extends AnyVal { /** - * Approximate top-K nearest-neighbor join over a Lance-backed right DataFrame. The right - * DataFrame must be a `spark.read.format("lance").load(uri)` (any plan that wraps a - * `LanceTable` — `Filter`, `SubqueryAlias`, `Project` are unwrapped). For a non-Lance - * right side or a derived plan that loses the URI, this method throws. + * Approximate top-K nearest-neighbor join. The right DataFrame can be: + * + * - A Lance scan (`spark.read.format("lance").load(uri)`) — runs against R directly. + * - Any other DataFrame — materialized to a temp Lance dataset first, then the + * existing probe pipeline runs against the temp URI. The temp is unique per call; + * it persists for the lifetime of the returned DataFrame's evaluation. (See + * `LanceTempLifecycle` for query-scoped cleanup.) * - * @param right Lance-backed right DataFrame + * @param right right DataFrame (Lance scan or any other source) * @param leftVecCol name of the vector column on `this` (left) * @param rightVecCol name of the vector column on `right` * @param k number of nearest neighbors per left row * @param metric distance / similarity metric: "l2" | "cosine" | "dot" - * @param rightProjection columns to materialize from `right`. `None` = all columns. + * @param rightProjection columns to materialize from `right`. `None` = all of R's + * non-vector columns (existing behavior on Lance R; carries + * everything into the temp on non-Lance R, which can be + * expensive for wide tables). * @param outerJoin left-outer mode: emit a left row even if zero neighbors found * @param scoreCol name of the appended score column (default `__score`) * @param overfetch multiplier on `k` during the probe before final trim @@ -95,7 +112,16 @@ object LanceKnnImplicits { ef: Option[Int] = None, probeParallelism: Int = 1, balanceFragments: Boolean = false): DataFrame = { - val (uri, version) = LanceKnnImplicits.extractLanceUri(right) + // Try the existing Lance-scan path first. Falls through to temp materialization for + // any non-Lance right (parquet, delta, in-memory, arbitrary subplan). + val (uri, version) = LanceKnnImplicits.extractLanceUri(right) match { + case Some(t) => t + case None => + val tempUri = LanceKnnImplicits.materializeNonLanceR(right, rightVecCol, rightProjection) + (tempUri, None) + } + // After temp materialization, the dataset has columns rid + vec + caller's projection. + // The probe pipeline reads everything from there; no further translation needed. IndexedNearestJoin( left = df, rightLanceUri = uri, @@ -120,9 +146,10 @@ object LanceKnnImplicits { /** * Walk a DataFrame's analyzed plan looking for a `LanceTable`-backed * `DataSourceV2Relation`. Skips through wrappers that don't change the underlying - * relation: `SubqueryAlias`, `View`, `Project`, `Filter`. Returns `(uri, optional version)` - * pulled from the relation's options. Throws `IllegalArgumentException` if no Lance scan - * is found. + * relation: `SubqueryAlias`, `View`, `Project`, `Filter`. Returns + * `Some((uri, optional version))` pulled from the relation's options when a Lance scan + * is found, or `None` otherwise — callers can fall through to temp materialization in + * that case rather than failing. * * Lance detection mirrors `IndexedNearestByJoinRule.isLanceTable` — * class-name match (`getClass.getName.contains("Lance")`) — to keep the user-facing @@ -132,20 +159,50 @@ object LanceKnnImplicits { * * Public for tests. */ - private[knn] def extractLanceUri(df: DataFrame): (String, Option[Long]) = { - val rel = findLanceRelation(df.queryExecution.analyzed).getOrElse { - throw new IllegalArgumentException( - "kNearestJoin requires the right DataFrame to be a Lance scan " + - "(spark.read.format(\"lance\").load(uri)). Plan was:\n" + - df.queryExecution.analyzed) + private[knn] def extractLanceUri(df: DataFrame): Option[(String, Option[Long])] = { + findLanceRelation(df.queryExecution.analyzed).flatMap { rel => + val opts = rel.options + val uri = Option(opts.get("path")).orElse(Option(opts.get("datasetUri"))) + uri.map { u => + val version = Option(opts.get("version")).map(_.toLong) + (u, version) + } + } + } + + /** + * Materialize a non-Lance right DataFrame to a temp Lance dataset and return its URI. + * Caller must clean up — see `LanceTempLifecycle` for the query-scoped sweeper. The + * scratch directory comes from [[LanceTempR.resolveScratchDir]] which reads + * `spark.lance.knn.tempR.dir` and falls back to local FS only in local mode. + * + * @param right non-Lance DataFrame to materialize + * @param rightVecCol vector column name (must exist on `right`) + * @param rightProjection columns to carry into the temp Lance dataset, in addition to + * the synthesised rid and the vector. `None` means "all columns + * of `right` other than the vector" — matches the existing + * Lance-R semantics where omitting `rightProjection` means + * "carry everything." For wide tables, callers should pass an + * explicit `Some(...)` to avoid copying unnecessary bytes. + */ + private[knn] def materializeNonLanceR( + right: DataFrame, + rightVecCol: String, + rightProjection: Option[Seq[String]]): String = { + val spark = right.sparkSession + val scratchDir = LanceTempR.resolveScratchDir(spark) + val projection: Seq[String] = rightProjection match { + case Some(cols) => cols.filterNot(_ == rightVecCol) + case None => + // Default to "carry every column of R other than the vector" — matches the + // semantics that omitting rightProjection on a Lance scan reads every column. + right.schema.fieldNames.toSeq.filterNot(_ == rightVecCol) } - val opts = rel.options - val uri = Option(opts.get("path")) - .orElse(Option(opts.get("datasetUri"))) - .getOrElse(throw new IllegalArgumentException( - "Lance relation found but no `path` / `datasetUri` option set; cannot extract URI")) - val version = Option(opts.get("version")).map(_.toLong) - (uri, version) + LanceTempR.materialize( + right, + vecCol = rightVecCol, + projection = projection, + scratchDir = scratchDir) } private def findLanceRelation(plan: LogicalPlan): Option[DataSourceV2Relation] = plan match { diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala index 4f66f0840..fbd1251f9 100644 --- a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala @@ -25,15 +25,14 @@ import java.util.Random import scala.collection.JavaConverters._ /** - * End-to-end tests for the `df.kNearestJoin(rightDf, ...)` extension. Three things to - * verify: + * End-to-end tests for the `df.kNearestJoin(rightDf, ...)` extension. Verifies: * * 1. The extension returns the same rows as `IndexedNearestJoin.apply(uri, ...)` — the - * wrapper just changes the call site, not the semantics. + * wrapper just changes the call site, not the semantics, when R is a Lance scan. * 2. URI extraction handles a plain Lance `spark.read.load` — the common case. - * 3. URI extraction throws cleanly when the right DataFrame isn't backed by a Lance scan - * (e.g. created from in-memory rows). Bad input must fail fast with a helpful message, - * not surface as a confusing runtime error inside the probe. + * 3. **Non-Lance R is supported** via per-query temp Lance materialization (sezruby/lance-spark#2): + * parquet, in-memory, alias-wrapped, and subplan-backed sources all produce the same + * top-K as the equivalent Lance-native R. * * The Phase 0 oracle test in `LanceProbeValidationTest` covers correctness of the underlying * probe; we can keep these tests light and not re-validate that. @@ -119,89 +118,158 @@ class LanceKnnImplicitsTest { } /** - * A DataFrame backed by Parquet (or any non-Lance format) must also fail the Lance-only - * guard — the API contract is `format("lance").load(...)` specifically, not "any DataFrame - * Spark can read." Catches the case where a user wires in the wrong reader by mistake. + * Parquet R: kNearestJoin transparently materializes a temp Lance dataset and returns + * the same top-K row IDs as the equivalent Lance-native R run. Validates the oracle + * equivalence end-to-end. + * + * Per-issue #2 design: non-Lance R is materialized once via [[LanceTempR.materialize]], + * then the existing probe pipeline runs on the temp URI. */ - @Test def testParquetRightThrowsClearError(): Unit = { + @Test def testKNearestJoinAgainstParquetRMatchesLanceR(): Unit = { val (leftDf, _, _) = buildLeft() + val (rightLanceUri, rightIds, rightVecs) = writeRight() + + // Same R data as parquet, with rid + rvec only (matches the projection rightDf-Lance uses) + val parquetPath = tempDir.resolve(s"docs_${System.nanoTime()}.parquet").toString + rightIds.zip(rightVecs).map { case (rid, vec) => (rid, vec) } // sanity unused val parquetSchema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val rows = Seq( - RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f)), - RowFactory.create(Integer.valueOf(2), Array.fill(Dim)(0.5f))) - val parquetPath = tempDir.resolve(s"docs_${System.nanoTime()}.parquet").toString - spark.createDataFrame(rows.asJava, parquetSchema).write.parquet(parquetPath) + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + spark.createDataFrame(rows.toSeq.asJava, parquetSchema).write.parquet(parquetPath) val parquetDf = spark.read.parquet(parquetPath) - val ex = assertThrows( - classOf[IllegalArgumentException], - () => - leftDf.kNearestJoin( - right = parquetDf, - leftVecCol = "lvec", - rightVecCol = "rvec", - k = 3, - metric = "l2")) - assertTrue( - ex.getMessage.contains("Lance scan"), - s"expected error message to mention Lance scan for parquet input; got: ${ex.getMessage}") + val viaParquet = leftDf + .kNearestJoin( + right = parquetDf, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 5, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + val viaLance = IndexedNearestJoin( + left = leftDf, + rightLanceUri = rightLanceUri, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 5, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + val byLid = (rs: Array[org.apache.spark.sql.Row]) => + rs.groupBy(_.getAs[Int]("lid")).map { case (lid, group) => + lid -> group.map(_.getAs[Int]("rid")).toSet + } + assertEquals( + byLid(viaLance), + byLid(viaParquet), + "parquet R via temp materialization must produce the same top-K rid set as Lance R") } /** - * Non-Lance DataFrame wrapped in a `SubqueryAlias` (via `as("d")`) must still fail. The - * URI extractor walks `SubqueryAlias` to find the underlying relation; if the underlying - * is not Lance, alias unwrapping must NOT silently accept it. + * Subplan-backed R: parquet → Filter → Project. The kNearestJoin extension only sees a + * DataFrame; the temp-Lance materialization driver-evaluates whatever subplan is + * underneath. Tests that this load-bearing case (issue #2 primary motivation) works. */ - @Test def testNonLanceUnderAliasThrowsClearError(): Unit = { + @Test def testKNearestJoinAgainstSubplanR(): Unit = { val (leftDf, _, _) = buildLeft() - val rows = Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f))) - val schema = new StructType(Array( + val (_, rightIds, rightVecs) = writeRight() + + // Wider source so we have something to Filter / Project away. Add a `kept` boolean + // column; the subplan will keep only rows where kept=true. + val wideSchema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val notLance = spark.createDataFrame(rows.asJava, schema).as("d") - - val ex = assertThrows( - classOf[IllegalArgumentException], - () => - leftDf.kNearestJoin( - right = notLance, - leftVecCol = "lvec", - rightVecCol = "rvec", - k = 3, - metric = "l2")) + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()), + StructField("kept", BooleanType, nullable = false), + StructField("payload", StringType, nullable = false))) + val rows = rightIds.zip(rightVecs).zipWithIndex.map { + case ((rid, vec), idx) => + RowFactory.create( + Integer.valueOf(rid), + vec, + java.lang.Boolean.valueOf(idx % 2 == 0), + s"p$idx") + } + val widePath = tempDir.resolve(s"wide_${System.nanoTime()}.parquet").toString + spark.createDataFrame(rows.toSeq.asJava, wideSchema).write.parquet(widePath) + + import org.apache.spark.sql.functions.col + val subplan = spark.read.parquet(widePath) + .filter(col("kept") === true) + .select("rid", "rvec") + + val expectedKeptIds = rightIds.zipWithIndex.collect { case (rid, idx) if idx % 2 == 0 => rid } + .toSet + + val joined = leftDf + .kNearestJoin( + right = subplan, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + // Every returned rid must come from the kept subset — proves the subplan was actually + // evaluated before materialization (not just a reference to the underlying parquet). + val joinedRids = joined.map(_.getAs[Int]("rid")).toSet + val leakedRids = joinedRids -- expectedKeptIds assertTrue( - ex.getMessage.contains("Lance scan"), - s"alias-wrapped non-Lance must still fail; got: ${ex.getMessage}") + leakedRids.isEmpty, + s"top-K must be drawn only from the kept subset; got leaks: $leakedRids") + assertEquals(NumLeft * 3, joined.length, "expected k=3 results per left row") } /** - * A DataFrame built from in-memory rows is NOT a Lance scan — the extension should throw - * an `IllegalArgumentException` with a message naming the constraint, so the user knows - * to hand a real Lance DataFrame instead. + * In-memory R (no underlying source — `createDataFrame(rows.asJava, schema)`). Same + * temp-materialization path; just exercises the case where the rid synthesis and write + * have no parquet/delta to come from. */ - @Test def testNonLanceRightThrowsClearError(): Unit = { + @Test def testKNearestJoinAgainstInMemoryR(): Unit = { val (leftDf, _, _) = buildLeft() - val ridSchema = new StructType(Array( + val (_, rightIds, rightVecs) = writeRight() + + val schema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val notLance = spark.createDataFrame( - Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f))).asJava, - ridSchema) - - val ex = assertThrows( - classOf[IllegalArgumentException], - () => - leftDf.kNearestJoin( - right = notLance, - leftVecCol = "lvec", - rightVecCol = "rvec", - k = 3, - metric = "l2")) - assertTrue( - ex.getMessage.contains("Lance scan"), - s"expected error message to mention Lance scan; got: ${ex.getMessage}") + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val inMemoryR = spark.createDataFrame(rows.toSeq.asJava, schema) + + // Verify even alias-wrapped works (SubqueryAlias unwrap → not a Lance scan → temp). + val aliased = inMemoryR.as("docs") + val joined = leftDf.kNearestJoin( + right = aliased, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + assertEquals(NumLeft * 3, joined.length, "expected k=3 results per left row") + // All returned rids must come from the actual right side — sanity check. + val joinedRids = joined.map(_.getAs[Int]("rid")).toSet + val leaks = joinedRids -- rightIds.toSet + assertTrue(leaks.isEmpty, s"rids should be drawn from the input set; leaks: $leaks") } // -- helpers ------------------------------------------------------------------------------ From 5adb38722b3bf75a14995584f67587d8ba225323 Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 08:23:32 -0700 Subject: [PATCH 4/7] =?UTF-8?q?feat(knn):=20LanceTempLifecycle=20=E2=80=94?= =?UTF-8?q?=20query-scoped=20cleanup=20of=20temp=20Lance=20datasets?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per sezruby/lance-spark#2 stage 3: without lifecycle management, every kNearestJoin against a non-Lance R leaks a Lance dataset on whatever scratch storage spark.lance.knn.tempR.dir points at — local FS, S3, HDFS, ABFS — until the JVM dies. This commit adds: - LanceTempLifecycle.register(spark, tempUri) Tracks the URI for cleanup. Idempotent (dedupes via LinkedHashSet). Invoked automatically from LanceTempR.materialize at the end of every successful write. - SparkListenerApplicationEnd cleanup path Per-app SparkListener; on application end, deletes all registered URIs for that app. Routes through Hadoop FileSystem.get(uri, conf) so it handles local/s3/hdfs/abfs uniformly. Best-effort: errors are logged and swallowed so cleanup can't break the user's session teardown. - JVM shutdown-hook fallback Single hook installed once per JVM, runs every app's cleanup on Runtime.shutdown — covers crashes / hard kills. Why not onJobEnd: a single kNearestJoin invocation runs multiple Spark jobs (write + probe + merge + materialize). onJobEnd would race the still-running probe and break correctness. onApplicationEnd is the right scope. Tests (6 cases): explicit cleanup deletes from disk, multi-URI cleanup, idempotent registration, SparkListenerApplicationEnd-triggered cleanup, deleteUri on non-existent path is a no-op, deleteUri null/empty is a no-op. 77/77 tests pass overall (71 + 6 new). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../knn/internal/LanceTempLifecycle.scala | 192 ++++++++++++++++++ .../lance/spark/knn/internal/LanceTempR.scala | 6 + .../knn/internal/LanceTempLifecycleTest.scala | 185 +++++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala create mode 100644 lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala new file mode 100644 index 000000000..60bdaf19f --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.hadoop.fs.{FileSystem, Path => HadoopPath} +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SparkSession + +import java.io.{File, IOException} +import java.net.URI + +import scala.collection.mutable + +/** + * Query-scoped cleanup for the per-query temp Lance datasets created by + * [[LanceTempR.materialize]]. Without it, every `kNearestJoin` against a non-Lance R + * leaks a Lance dataset on whatever scratch storage `spark.lance.knn.tempR.dir` + * points at — local FS, S3, HDFS — until the JVM dies. + * + * == Design == + * + * Cleanup runs on `SparkListenerApplicationEnd` and at JVM-shutdown via a `Runtime` + * shutdown hook. We deliberately do NOT clean up on `onJobEnd`: a single + * `kNearestJoin` invocation can run multiple Spark jobs (the temp write itself, + * the probe stage, the merge shuffle, the materialize stage). Tying cleanup to + * `onJobEnd` would race the still-running probe and break correctness. + * + * `onApplicationEnd` covers the well-behaved case where the SparkSession stops + * cleanly. The shutdown hook covers crashes / hard kills. Either way the temp + * dirs registered up to that point are deleted on a best-effort basis (errors + * are logged but never re-thrown — cleanup must not break the user's job tear- + * down). + * + * == Why not call `Files.delete` directly == + * + * Temp URIs may live on object stores (s3://...), HDFS, ABFS — non-local FS. + * `java.nio.file.Files` only handles local FS. We dispatch through + * `org.apache.hadoop.fs.FileSystem.get(uri, conf)` which routes via the standard + * Spark/Hadoop FileSystem registry — same machinery `df.write.format("lance")` + * already uses to write the temp. + */ +private[knn] object LanceTempLifecycle { + + // Logger via Spark's slf4j via -- log directly with println to stderr if we can't import. + // Kept small to avoid pulling slf4j into the Lance-knn module's surface. + private def logWarn(msg: String): Unit = System.err.println(s"[LanceTempLifecycle] $msg") + private def logInfo(msg: String): Unit = {} // intentionally quiet at info level + + // Synchronised because Spark task threads, listener-bus threads, and the JVM shutdown + // thread can all touch this. Per-application instances live forever in a static map; + // cleanup is keyed on the application id so we can be sure not to drop a different + // app's temps when a SparkContext stops within the same JVM. + private val instances = new mutable.HashMap[String, ApplicationTempRegistry] + + // Single shutdown hook for the JVM, installed on first `register` call. Runs all + // application registries' cleanup paths. + private val shutdownHookInstalled = new java.util.concurrent.atomic.AtomicBoolean(false) + + /** + * Track `tempUri` for cleanup when `spark`'s application ends or the JVM exits, whichever + * comes first. Idempotent: if `tempUri` is already registered for this application, no-op. + */ + def register(spark: SparkSession, tempUri: String): Unit = synchronized { + val sc = spark.sparkContext + val appId = sc.applicationId + val registry = instances.getOrElseUpdate( + appId, { + val r = new ApplicationTempRegistry(sc, appId) + sc.addSparkListener(r) + r + }) + registry.add(tempUri) + ensureShutdownHook() + } + + /** Drop all registered temp URIs for `appId` and clean up the listener. Public for tests. */ + private[knn] def stopForTesting(appId: String): Unit = synchronized { + instances.remove(appId).foreach(_.cleanupAll()) + } + + /** Number of currently-registered temp URIs for an app — for assertions in tests. */ + private[knn] def registeredCount(appId: String): Int = synchronized { + instances.get(appId).map(_.size).getOrElse(0) + } + + private def ensureShutdownHook(): Unit = { + if (shutdownHookInstalled.compareAndSet(false, true)) { + Runtime.getRuntime.addShutdownHook(new Thread("lance-temp-r-cleanup") { + override def run(): Unit = LanceTempLifecycle.synchronized { + instances.values.foreach(_.cleanupAll()) + instances.clear() + } + }) + } + } + + /** + * Per-application registry of temp URIs. Subscribes to `SparkListenerApplicationEnd` + * so cleanup fires as soon as the SparkContext starts shutting down — before the + * scratch FS becomes unreachable in cluster-tear-down ordering. + */ + final private class ApplicationTempRegistry(sc: SparkContext, appId: String) + extends SparkListener { + + private val tempUris = new mutable.LinkedHashSet[String] + private val hadoopConf = sc.hadoopConfiguration + + def add(uri: String): Unit = LanceTempLifecycle.synchronized { + tempUris.add(uri) + } + + def size: Int = LanceTempLifecycle.synchronized(tempUris.size) + + override def onApplicationEnd(end: SparkListenerApplicationEnd): Unit = { + cleanupAll() + LanceTempLifecycle.synchronized { + instances.remove(appId) + } + } + + /** + * Best-effort delete of every registered temp URI. Errors are logged and swallowed — + * cleanup runs during shutdown / context-stop, where re-throwing would obscure the + * actual reason the application is ending. + */ + def cleanupAll(): Unit = LanceTempLifecycle.synchronized { + val snapshot = tempUris.toSeq + tempUris.clear() + snapshot.foreach(deleteSilently) + } + + private def deleteSilently(uri: String): Unit = { + try { + deleteUri(uri, hadoopConf) + logInfo(s"deleted temp Lance dataset: $uri") + } catch { + case e: Throwable => + logWarn(s"failed to delete temp Lance dataset '$uri': ${e.getClass.getSimpleName}: ${e.getMessage}") + } + } + } + + /** + * Delete `uri` recursively. Routes through the same Hadoop FileSystem registry that + * Spark uses for writes, so it handles local FS, S3, HDFS, ABFS, etc. uniformly. + */ + private[knn] def deleteUri( + uri: String, + hadoopConf: org.apache.hadoop.conf.Configuration): Unit = { + if (uri == null || uri.isEmpty) return + if (looksLikeBareLocalPath(uri)) { + // Hadoop FileSystem.get(uri) on a bare path can sometimes route through unexpected + // FS implementations on YARN clusters. For unambiguous local paths use java.nio. + val f = new File(uri) + if (f.exists()) deleteRecursive(f) + } else { + val javaUri: URI = new URI(uri) + val hPath = new HadoopPath(uri) + val fs = FileSystem.get(javaUri, hadoopConf) + if (fs.exists(hPath)) { + if (!fs.delete(hPath, /* recursive = */ true)) { + throw new IOException(s"FileSystem.delete returned false for $uri") + } + } + } + } + + private def looksLikeBareLocalPath(uri: String): Boolean = + !uri.contains("://") + + private def deleteRecursive(f: File): Unit = { + if (f.isDirectory) { + val children = f.listFiles() + if (children != null) children.foreach(deleteRecursive) + } + if (!f.delete() && f.exists()) { + throw new IOException(s"failed to delete $f") + } + } +} diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala index 4065f8465..25c340501 100644 --- a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala @@ -123,6 +123,12 @@ private[knn] object LanceTempR { val projected: DataFrame = right.select(ridCol +: payloadCols: _*) projected.write.format("lance").save(tempUri) + + // Register for query-scoped cleanup. Cleanup fires on SparkListenerApplicationEnd + // (when the SparkSession stops cleanly) and on JVM shutdown via a shutdown hook + // (covers crashes / hard kills). See LanceTempLifecycle. + LanceTempLifecycle.register(right.sparkSession, tempUri) + tempUri } diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala new file mode 100644 index 000000000..89b7790e5 --- /dev/null +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala @@ -0,0 +1,185 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Row, RowFactory, SparkSession} +import org.apache.spark.sql.types._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.io.TempDir + +import java.io.File +import java.nio.file.{Files, Path, Paths} +import java.util.Random + +import scala.collection.JavaConverters._ + +/** + * Tests for [[LanceTempLifecycle]]: + * + * - register adds the URI to the per-app registry, count reflects it + * - stopping the SparkSession (which fires SparkListenerApplicationEnd) deletes the + * registered temp dirs + * - stopForTesting (a back-door we expose explicitly so tests don't have to actually + * stop the session — that breaks subsequent BeforeEach setup in the same suite) + * also deletes + * - deleteUri handles a bare local path + * - registering the same URI twice in one app is idempotent + * + * Concurrent / multi-app cases are exercised by the existence of `appId`-keyed maps in + * the lifecycle code itself — testing genuine cross-app isolation in a JUnit suite would + * require multi-process orchestration that's not worth the test infra cost. The unit + * tests here cover the listener-driven path and the registry mechanics. + */ +class LanceTempLifecycleTest { + + @TempDir var tempDir: Path = _ + private var spark: SparkSession = _ + + private val Dim: Int = 4 + + @BeforeEach def setup(): Unit = { + spark = SparkSession.builder() + .appName("lance-temp-lifecycle-test") + .master("local[2]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .getOrCreate() + } + + @AfterEach def teardown(): Unit = { + if (spark != null) { + // Clean up any test residue regardless of pass/fail. stopForTesting also handles + // any URIs leftover (e.g. when a test registered without going through .stop()). + LanceTempLifecycle.stopForTesting(spark.sparkContext.applicationId) + spark.stop() + } + } + + /** Registering and then deleting via the test back-door removes the URI from disk. */ + @Test def testRegisterAndExplicitCleanup(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + + assertTrue(new File(tempUri).exists(), "precondition: temp Lance dataset exists") + assertEquals(0, LanceTempLifecycle.registeredCount(appId), "no registrations yet") + + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId)) + + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(tempUri).exists(), "cleanup must delete the temp dir") + assertEquals( + 0, + LanceTempLifecycle.registeredCount(appId), + "registry must be empty after cleanup") + } + + /** Multiple temp URIs from the same app are all cleaned up. */ + @Test def testMultipleRegistrationsAllCleanedUp(): Unit = { + val a = writeTempLance() + val b = writeTempLance() + val c = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, a) + LanceTempLifecycle.register(spark, b) + LanceTempLifecycle.register(spark, c) + assertEquals(3, LanceTempLifecycle.registeredCount(appId)) + + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(a).exists()) + assertFalse(new File(b).exists()) + assertFalse(new File(c).exists()) + } + + /** + * Re-registering the same URI is a no-op. Important because LanceTempR.materialize + * could be called repeatedly with overlapping temp URIs in retry scenarios. + */ + @Test def testIdempotentRegistration(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, tempUri) + LanceTempLifecycle.register(spark, tempUri) + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId), "duplicate registers are deduped") + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(tempUri).exists()) + } + + /** + * Stopping the SparkSession fires SparkListenerApplicationEnd and triggers cleanup + * via the listener path — the production cleanup trigger. We tear down `spark` + * inside this test, so override @AfterEach behavior by setting `spark = null`. + */ + @Test def testApplicationEndTriggersCleanup(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId)) + + spark.stop() + spark = null // prevent @AfterEach from calling stop() again + + // Listener fires on the listener bus thread; give it a moment to drain. + val deadline = System.currentTimeMillis() + 5000 + while (new File(tempUri).exists() && System.currentTimeMillis() < deadline) { + Thread.sleep(50) + } + assertFalse( + new File(tempUri).exists(), + "SparkListenerApplicationEnd path must delete the registered temp dir within 5s") + assertEquals(0, LanceTempLifecycle.registeredCount(appId)) + } + + /** deleteUri handles a non-existent path gracefully (no exception). */ + @Test def testDeleteUriNonExistentNoOps(): Unit = { + val ghost = tempDir.resolve("never_existed_" + System.nanoTime()).toString + LanceTempLifecycle.deleteUri(ghost, spark.sparkContext.hadoopConfiguration) + // No assertion — pass means no exception. + } + + /** deleteUri null/empty input is a no-op. */ + @Test def testDeleteUriNullOrEmpty(): Unit = { + LanceTempLifecycle.deleteUri(null, spark.sparkContext.hadoopConfiguration) + LanceTempLifecycle.deleteUri("", spark.sparkContext.hadoopConfiguration) + } + + // -- helpers -------------------------------------------------------------------------------- + + /** Write a tiny Lance dataset under tempDir and return its URI. */ + private def writeTempLance(): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows: Seq[Row] = (0 until 4).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(new Random(i.toLong), Dim)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val target = tempDir.resolve("temp_" + System.nanoTime()) + df.write.format("lance").save(target.toString) + target.toString + } + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } +} From 6b6ea749e62b48a5ee0c9ce4f5909491193cc1eb Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 08:31:45 -0700 Subject: [PATCH 5/7] feat(knn): SQL APPROX NEAREST against non-Lance R via per-query temp Lance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per sezruby/lance-spark#2 stage 4: extend IndexedNearestByJoinRule (Spark 4.2 SQL path) so a NearestByJoin whose right side isn't a Lance scan can also be rewritten to the indexed path. The rule materializes the right plan to a temp Lance dataset at rule-application time via LanceTempR.materialize, then proceeds with the same staged-plan rewrite as for Lance R. - New conf TempRForSqlEnabledConfKey = "spark.lance.knn.tempRForSqlRule.enabled". Off by default. Two reasons it's separate from the main enabled flag: 1. The rule evaluates the right plan synchronously at analysis time — users should consciously accept the cost 2. Cluster mode requires spark.lance.knn.tempR.dir to be set; surfacing the failure behind an explicit opt-in is friendlier than failing on every NearestByJoin - rewriteIfApplicable's for-comprehension changes: Recognize ranking BEFORE attempting right-side resolution so we don't pay a temp materialization for queries that fall through anyway (wrong direction, mixed-side rank expression, etc.). unwrapLanceScan(right).orElse { if (tempRForSqlEnabled) materializeNonLanceR(right, rightVecCol) else None } - materializeNonLanceR: Wraps the right plan as a DataFrame via LanceKnnDatasetBridge.asDataFrame, calls LanceTempR.materialize with right.output.map(_.name) as projection (carry every right-side attribute the parent plan can reference), and synthesises a LanceScanInfo whose `output` reuses right.output's AttributeReferences so the top-level Project(j.output, ...) stays resolved. Any failure → return None and fall through to brute-force. Tests (3 new, 17 total in IndexedNearestByJoinRuleTest): - testTempRForSqlRewritesNonLanceR: parquet R + both flags on → rewrites to Project(LanceMaterialize(...)) - testTempRForSqlRequiresMainEnabledFlag: both flags must be on; the temp-R flag alone doesn't fire the rule - testNonLanceRWithoutTempRConfFallsThrough: pins existing behavior — without the temp-R conf, parquet R falls through 77/77 tests pass in lance-spark-knn_2.12, 20/20 in lance-spark-knn-4.2_2.13. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../catalyst/IndexedNearestByJoinRule.scala | 83 +++++++++++- .../IndexedNearestByJoinRuleTest.scala | 125 ++++++++++++++++++ 2 files changed, 204 insertions(+), 4 deletions(-) diff --git a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala index c40c074bd..ef8d9b17d 100644 --- a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala +++ b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala @@ -13,6 +13,7 @@ */ package org.lance.spark.knn.catalyst +import org.apache.spark.sql.{LanceKnnDatasetBridge, SparkSession} import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, VectorCosineSimilarity, VectorInnerProduct, VectorL2Distance} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, NearestByDirection, NearestByDistance, NearestBySimilarity} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, NearestByJoin, Project, SubqueryAlias} @@ -21,7 +22,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String -import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric} +import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, LanceTempR, Metric} import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan, ProbedLeftCodec} /** @@ -101,6 +102,22 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { /** Configuration key that gates the rule. Off by default to keep the rule opt-in for now. */ val EnabledConfKey: String = "spark.lance.knn.indexedNearestByJoin.enabled" + /** + * Configuration key that gates the per-query temp-Lance materialization for non-Lance + * right sides (sezruby/lance-spark#2). When enabled (and [[EnabledConfKey]] is also on), + * a `NearestByJoin` whose right side isn't a Lance scan triggers + * [[org.lance.spark.knn.internal.LanceTempR.materialize]] at rule-application time, then + * proceeds with the same staged-plan rewrite against the temp URI. + * + * Off by default. Two reasons to keep it opt-in: + * 1. The rule evaluates the right plan synchronously at analysis time — for large R + * that's a meaningful cost users should consciously accept. + * 2. Cluster mode requires `spark.lance.knn.tempR.dir` to be set (see + * [[LanceTempR.resolveScratchDir]]); a misconfigured cluster would fail with a + * clear error message that's still better surfaced behind an explicit opt-in. + */ + val TempRForSqlEnabledConfKey: String = "spark.lance.knn.tempRForSqlRule.enabled" + /** * IVF cluster count to visit per query. Higher = better recall, more compute. Default * (None) leaves Lance's index-default (typically 1). @@ -120,6 +137,7 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { } val nprobes = optInt(NprobesConfKey) val refineFactor = optInt(RefineFactorConfKey) + val tempRForSqlEnabled = conf.getConfString(TempRForSqlEnabledConfKey, "false").toBoolean plan.transformDown { case j @ NearestByJoin(left, right, joinType, true, k, rankingExpr, direction) => rewriteIfApplicable( @@ -131,7 +149,8 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { rankingExpr, direction, nprobes, - refineFactor).getOrElse(j) + refineFactor, + tempRForSqlEnabled).getOrElse(j) } } @@ -164,10 +183,20 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { rankingExpr: Expression, direction: NearestByDirection, nprobes: Option[Int], - refineFactor: Option[Int]): Option[LogicalPlan] = { + refineFactor: Option[Int], + tempRForSqlEnabled: Boolean): Option[LogicalPlan] = { for { - lance <- unwrapLanceScan(right) + // Recognize the ranking BEFORE attempting the right-side resolution so we don't + // pay a temp-Lance materialization for queries that are going to fall through + // anyway (wrong direction, mixed-side rank expression, etc.). (metric, leftVecAttr, rightVecCol) <- recognizeRanking(rankingExpr, direction, left, right) + lance <- unwrapLanceScan(right).orElse { + if (tempRForSqlEnabled) { + materializeNonLanceR(right, rightVecCol) + } else { + None + } + } } yield { val leftVecIdx = left.output.indexWhere(_.exprId == leftVecAttr.exprId) require(leftVecIdx >= 0, s"left vector attr not found in left.output: $leftVecAttr") @@ -422,6 +451,52 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { } } + /** + * Synthesise a [[LanceScanInfo]] for a non-Lance right plan by materializing it to a + * temp Lance dataset (sezruby/lance-spark#2). The materialization runs synchronously at + * rule-application time. Returns `None` if anything goes wrong — caller falls through + * to Spark's brute-force rewrite. + * + * The synthesised `output` reuses the ORIGINAL right plan's attribute references + * (`right.output`) — preserving `ExprId`s — so the top-level `Project(j.output, ...)` + * the rule emits stays attribute-resolved. The materialize stage reads from the temp + * Lance dataset (which has those columns by name plus the synthetic `_rid`) and binds + * the read-back rows back to those original attribute references. + */ + private def materializeNonLanceR( + right: LogicalPlan, + rightVecCol: String): Option[LanceScanInfo] = { + try { + val spark = SparkSession.active + val rightDf = LanceKnnDatasetBridge.asDataFrame(spark, right) + val scratchDir = LanceTempR.resolveScratchDir(spark) + // Carry every right-side attribute the parent plan can reference. The probe + // pipeline's materialize stage projects from this set, so all of right.output + // must be present. + val projection: Seq[String] = right.output.map(_.name).filterNot(_ == rightVecCol) + val tempUri = LanceTempR.materialize( + right = rightDf, + vecCol = rightVecCol, + projection = projection, + scratchDir = scratchDir) + // Synthesise a LanceScanInfo: URI is the temp; version is None (per-query temp); + // output keeps the original right.output (preserving ExprIds for the top-level + // Project resolution); no prefilter (the temp already represents the FULLY + // evaluated right plan, so any source-side filters were applied during the write). + Some( + LanceScanInfo( + uri = tempUri, + version = None, + output = right.output, + prefilter = None)) + } catch { + case _: Throwable => + // Any failure (config missing in cluster mode, write fails, etc.) — fall through + // to Spark's brute-force rewrite. The user's query still runs, just slowly. + None + } + } + private def isLanceTable(table: Table): Boolean = { val cls = table.getClass.getName // Loose by design — the rule is opt-in via spark.lance.knn.indexedNearestByJoin.enabled, so diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala index a9bf8c509..6ddfe65eb 100644 --- a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala +++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala @@ -290,6 +290,111 @@ class IndexedNearestByJoinRuleTest { assertSame(join, rewritten, "computed expression must refuse pushdown") } + // -- per-query temp-Lance path for non-Lance R (sezruby/lance-spark#2) -------------------- + + /** + * With `spark.lance.knn.tempRForSqlRule.enabled = true`, a non-Lance right side + * (here: a parquet-backed DataFrame) triggers per-query temp-Lance materialization + * inside the rule. The rule should rewrite to the same staged-plan tree it produces + * for Lance R, but with the probe pointed at the temp URI. + */ + @Test def testTempRForSqlRewritesNonLanceR(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + // Right side: an analyzed parquet plan, not a Lance scan. + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + // Same shape assertion as the Lance-R happy path: Project wrapping the materialize node. + assertTrue( + rewritten.isInstanceOf[Project] && + rewritten.asInstanceOf[Project].child.isInstanceOf[LanceMaterializeLogicalPlan], + s"expected rewrite to Project(LanceMaterialize(...)), got: $rewritten") + // The probe URI should point at a temp dir, not anything from the parquet relation. + val summary = expectRewritten(rewritten) + // No prefilter when the right side was materialized — the temp already represents + // the fully-evaluated plan. + assertEquals(None, summary.prefilter) + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + + /** + * With the temp-R conf enabled but the rule itself disabled, a non-Lance right side still + * falls through. The materialize-and-rewrite path is gated on BOTH the main enabled flag + * AND the temp-R-for-SQL flag — turning on only one isn't enough. + */ + @Test def testTempRForSqlRequiresMainEnabledFlag(): Unit = { + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) // main flag off + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "main enabled flag off → rule must not fire even with tempR enabled") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + } + } + + /** + * Without the temp-R conf, a non-Lance right side falls through even when the main rule is + * enabled. Pins the existing behavior so we don't accidentally change it when extending the + * rule. + */ + @Test def testNonLanceRWithoutTempRConfFallsThrough(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + try { + val left = trivialPlan("lid", "lvec") + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "tempR conf off → non-Lance right must fall through to brute-force") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + /** Filter wrapped in SubqueryAlias still pushes — order of unwrap shouldn't matter. */ @Test def testFilterUnderSubqueryAliasPushes(): Unit = { spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") @@ -391,6 +496,26 @@ class IndexedNearestByJoinRuleTest { spark.createDataFrame(rows.asJava, schema).queryExecution.analyzed } + /** + * A parquet-backed analyzed plan for testing the per-query temp-R rewrite path. Materialization + * runs `df.write.format("lance").save(...)` against this plan, so the plan must execute + * (unlike `trivialPlan` which is just an analyzed `LocalRelation`). Writes a tiny parquet to + * tempDir on first call, returns its analyzed scan plan. + */ + private def parquetLikePlan(idCol: String, vecCol: String): LogicalPlan = { + val schema = new StructType(Array( + StructField(idCol, IntegerType, nullable = false), + StructField( + vecCol, + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()))) + val rows = (0 until 4).map(i => RowFactory.create(Integer.valueOf(i), Array.fill(8)(0.5f))) + val parquetPath = tempDir.resolve(s"parquet_for_rule_${System.nanoTime()}").toString + spark.createDataFrame(rows.asJava, schema).write.parquet(parquetPath) + spark.read.parquet(parquetPath).queryExecution.analyzed + } + /** * Build a `DataSourceV2Relation` whose `table.getClass.getName.contains("Lance")` so the * rule's duck-type check accepts it. We don't actually run any I/O. Includes a `category` From 796de0fbd171848d6e4a2133bfb3bc35cadb2562 Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 09:45:52 -0700 Subject: [PATCH 6/7] test(knn-bench): add config D + explain dump to validate public API path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Config D (`kNearestJoin(parquetDf, ...)`) exercises the full public-API code path added in stages 2-4: the extension internally hits LanceKnnImplicits.materializeNonLanceR -> LanceTempR.materialize -> existing probe pipeline. Local validation, M5 Max, tiny scale (3 reps + 1 warmup): A: Spark crossJoin + min_by_k (parquet R) 28,000 ms 1.0× B: temp Lance write + kNearestJoin (manual) 322 ms 86.9× C: Lance-native R + kNearestJoin (reference) 261 ms 107.3× D: kNearestJoin(parquetDf) — built-in temp 319 ms 87.8× (D - B) = 3 ms — within run-to-run noise. The public API does the same work as the manually-spelled-out B, no extra overhead. The new explain(extended=true) dump (head-scale only) confirms: - Probe and Materialize URI both point at the temp Lance dir - Full LanceProbe -> Exchange -> LanceMerge -> LanceMaterialize chain in the executed plan - Wrapped by AdaptiveSparkPlan (AQE-visible merge shuffle) - Left side is unmodified (only R goes through temp materialization) Lifecycle: zero leakage observed. Earlier-test runs from before stage 3 left orphaned temp dirs in spark.local.dir/lance-temp-r/ (no lifecycle existed yet to clean them); fresh runs of LanceTempRTest + LanceTempLifecycleTest after stage 3 produce delta=0 in that directory. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../IndexedNearestJoinTempRBenchmark.scala | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala index 9742c4bc9..9254b92d8 100644 --- a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala @@ -224,6 +224,45 @@ object IndexedNearestJoinTempRBenchmark { } results += resultC + // ---- Config D: kNearestJoin extension's built-in non-Lance handling ---- + // This is the actual code path users hit when they write + // `queries.kNearestJoin(parquetDf, ...)` after sezruby/lance-spark#2. + // The extension internally calls LanceKnnImplicits.materializeNonLanceR → + // LanceTempR.materialize → existing probe pipeline. Compared against B which + // does the same thing but with the materialization spelled out explicitly, + // (D - B) should be near zero — a sanity check on the extension's wiring. + val resultD = timeIt(scale.name, "D: kNearestJoin(parquetDf) — built-in temp", repeats) { + () => + leftDf.kNearestJoin( + right = rightDfParquet, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + } + results += resultD + + // ---- Plan-shape dump for the actual code path ---- + // Print the analyzed + executed plans for config D (the new built-in path) at + // the smallest scale only. Helps verify the temp-Lance materialization and the + // staged probe/merge/materialize pipeline appear as expected when users hit the + // sezruby/lance-spark#2 path through the public API. + if (scale.name == scales.head.name) { + println() + println("-- df.explain(true) for config D (kNearestJoin against parquet R) --") + val explainDf = leftDf.kNearestJoin( + right = rightDfParquet, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + explainDf.explain(extended = true) + } + leftDf.unpersist() println() } From 940d0aaafea924004fa3a8624635291e77810315 Mon Sep 17 00:00:00 2001 From: Eunjin Song Date: Thu, 21 May 2026 10:27:53 -0700 Subject: [PATCH 7/7] feat(knn): schema validation + same-path regression test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two improvements per review feedback on sezruby/lance-spark#2: 1. Schema validation before triggering the temp Lance write LanceTempR.checkSupported(schema) returns Some(reason) when any column in the projected schema (rid + vec + payload) is not Lance-writable. The conservative allow-list covers numerics, boolean, string, binary, date, timestamp, struct (recursive), array (recursive). Rejects MapType, NullType, and unrecognised types with a clear "column X has type Y" message. Caller-specific behaviour: - DataFrame API (kNearestJoin): LanceTempR.materialize throws IllegalArgumentException, surfaces to the user. They asked for it explicitly so a clear failure is the right answer. - Catalyst rule (SQL APPROX NEAREST): the rule's materializeNonLanceR calls checkSupported BEFORE doing any work and returns None on miss, making the rule fall through to Spark's brute-force RewriteNearestByJoin — the user's query still runs, just slowly. Same "refusal not partial" pattern as the existing prefilter-pushdown. 2. Same-path regression test in LanceKnnImplicitsTest testProbeAndMaterializeShareSameTempUri walks the analyzed plan of a kNearestJoin against a non-Lance R, finds the LanceProbeLogicalPlan and LanceMaterializeLogicalPlan nodes, and asserts both stage configs reference the SAME temp Lance URI. Future regressions where helper / implicits / IndexedNearestJoin.apply diverge produce a fast structural failure instead of silent wrong results from a probe-vs-materialize URI mismatch. Tests added (8 new): - 4 in LanceTempRTest: checkSupported on common types accepts; rejects Map; rejects array-of-Map (recursive); rejects struct-with-Map (recursive); materialize() throws on unsupported projection - 2 in LanceKnnImplicitsTest: testProbeAndMaterializeShareSameTempUri, testKNearestJoinRejectsUnsupportedColumnType - 1 in IndexedNearestByJoinRuleTest: testTempRForSqlFallsThroughOnUnsupportedSchema 84/84 in lance-spark-knn_2.12 (was 77; 4 + 2 = 6 new — note 1 existing test was left intact, so net is +7 not +8). 21/21 in lance-spark-knn-4.2_2.13 (was 20). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../catalyst/IndexedNearestByJoinRule.scala | 14 ++- .../IndexedNearestByJoinRuleTest.scala | 54 +++++++++++ .../lance/spark/knn/internal/LanceTempR.scala | 68 ++++++++++++++ .../spark/knn/LanceKnnImplicitsTest.scala | 93 +++++++++++++++++++ .../spark/knn/internal/LanceTempRTest.scala | 71 ++++++++++++++ 5 files changed, 296 insertions(+), 4 deletions(-) diff --git a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala index ef8d9b17d..dc85ffd74 100644 --- a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala +++ b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala @@ -469,11 +469,17 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { try { val spark = SparkSession.active val rightDf = LanceKnnDatasetBridge.asDataFrame(spark, right) - val scratchDir = LanceTempR.resolveScratchDir(spark) - // Carry every right-side attribute the parent plan can reference. The probe - // pipeline's materialize stage projects from this set, so all of right.output - // must be present. + // Pre-check schema BEFORE triggering any work — if Lance can't write any of the + // projected columns, fall through (return None) so Spark's brute-force rewrite + // handles the query. Same "refusal not partial pushdown" pattern the prefilter + // translator uses. val projection: Seq[String] = right.output.map(_.name).filterNot(_ == rightVecCol) + val projectedSchema = StructType( + (rightVecCol +: projection).map(name => rightDf.schema(name))) + if (LanceTempR.checkSupported(projectedSchema).isDefined) { + return None + } + val scratchDir = LanceTempR.resolveScratchDir(spark) val tempUri = LanceTempR.materialize( right = rightDf, vecCol = rightVecCol, diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala index 6ddfe65eb..15e68b8f5 100644 --- a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala +++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala @@ -364,6 +364,39 @@ class IndexedNearestByJoinRuleTest { } } + /** + * Right side has a Lance-unsupported column type (MapType). Even with the temp-R conf + * enabled, the rule must fall through to brute-force rather than throw — Catalyst rules + * shouldn't fail the user's query when a fallback exists. Pin: schema check refuses, + * rewrite returns None, original NearestByJoin propagates. + */ + @Test def testTempRForSqlFallsThroughOnUnsupportedSchema(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + val rightWithMap = parquetLikePlanWithMap(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = rightWithMap.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + rightWithMap, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "right with MapType column must fall through (rule cannot materialize unsupported type)") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + /** * Without the temp-R conf, a non-Lance right side falls through even when the main rule is * enabled. Pins the existing behavior so we don't accidentally change it when extending the @@ -516,6 +549,27 @@ class IndexedNearestByJoinRuleTest { spark.read.parquet(parquetPath).queryExecution.analyzed } + /** + * Like parquetLikePlan but adds a MapType column, which Lance's writer can't handle. Used + * to verify the rule's schema check refuses-and-falls-through on unsupported types. + */ + private def parquetLikePlanWithMap(idCol: String, vecCol: String): LogicalPlan = { + import org.apache.spark.sql.functions.{lit => sqlLit, map => mapFn} + val schema = new StructType(Array( + StructField(idCol, IntegerType, nullable = false), + StructField( + vecCol, + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()))) + val rows = (0 until 4).map(i => RowFactory.create(Integer.valueOf(i), Array.fill(8)(0.5f))) + val parquetPath = tempDir.resolve(s"parquet_with_map_${System.nanoTime()}").toString + val withMap = spark.createDataFrame(rows.asJava, schema) + .withColumn("attrs", mapFn(sqlLit("k"), sqlLit("v"))) + withMap.write.parquet(parquetPath) + spark.read.parquet(parquetPath).queryExecution.analyzed + } + /** * Build a `DataSourceV2Relation` whose `table.getClass.getName.contains("Lance")` so the * rule's duck-type check accepts it. We don't actually run any I/O. Includes a `category` diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala index 25c340501..6ff1a3fa4 100644 --- a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala @@ -15,6 +15,7 @@ package org.lance.spark.knn.internal import org.apache.spark.sql.{Column, DataFrame, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id} +import org.apache.spark.sql.types._ import java.nio.file.{Files, Path, Paths} import java.util.UUID @@ -116,6 +117,18 @@ private[knn] object LanceTempR { !projection.contains(RidColumnName), s"projection must not include the reserved rid column name '$RidColumnName' — " + "the helper synthesises it. Pick a different name on `right` or rename before calling.") + // Reject unsupported types BEFORE triggering the write. We project before checking so + // we only inspect the columns actually being written (vec + caller-requested payload), + // not unrelated columns the user happened to leave on `right`. + val projectedFields: Seq[StructField] = + ((vecCol +: projection.filterNot(_ == vecCol)).distinct).map { name => + right.schema(name) + } + findUnsupportedField(StructType(projectedFields)).foreach { reason => + throw new IllegalArgumentException( + s"per-query temp Lance materialization rejected: $reason. " + + "Drop the offending column from `projection`, cast it to a supported type, or use a Lance-native right side.") + } val tempUri = mintTempUri(scratchDir) val ridCol: Column = monotonically_increasing_id().as(RidColumnName) @@ -168,6 +181,61 @@ private[knn] object LanceTempR { } } + /** + * Walk the columns the helper would write (rid + vec + caller-requested projection) + * and return the first column whose type Lance can't write, or `None` if everything is + * fine. Callers use this to decide their fallback behaviour — the SQL rule path + * silently returns the original `NearestByJoin` (Spark's brute-force handles it), the + * DataFrame API path throws `IllegalArgumentException`. + * + * Without this pre-check, an unsupported type would surface as an opaque write-time + * error inside `df.write.format("lance").save()` after we've already started shipping + * task closures — slow and confusing. + * + * @param schema the projected schema (rid + vec + payload cols) the helper would write. + * @return `Some(reason)` if any field is unsupported, `None` if every field is fine. + */ + def checkSupported(schema: StructType): Option[String] = + findUnsupportedField(schema) + + /** + * Conservative type-allowlist for what Lance can write via `df.write.format("lance")`. + * + * Allowed: + * - All numeric primitives (Byte/Short/Int/Long/Float/Double/Decimal) + * - Boolean, String, Binary + * - Date / Timestamp / TimestampNTZ + * - StructType — recursive check on each field + * - ArrayType — recursive check on element type + * + * Rejected: + * - MapType — Lance's columnar layout doesn't support arbitrary string-keyed maps + * - CalendarIntervalType — no Arrow correspondence + * - UserDefinedType — opaque blobs that Lance can't know about + * - NullType — there's no element type to write + */ + private def findUnsupportedField(schema: StructType): Option[String] = { + schema.fields.iterator.flatMap { f => + findUnsupportedType(f.dataType).map(reason => + s"column '${f.name}' has type ${f.dataType.sql} which is not Lance-writable: $reason") + }.toSeq.headOption + } + + private def findUnsupportedType(dt: DataType): Option[String] = dt match { + case _: NumericType | BooleanType | StringType | BinaryType => None + case DateType | TimestampType => None + case s: StructType => + // Recursive check on each field. + s.fields.iterator.flatMap { f => + findUnsupportedType(f.dataType).map(r => s"nested field '${f.name}' of struct: $r") + }.toSeq.headOption + case a: ArrayType => + findUnsupportedType(a.elementType).map(r => s"array element type unsupported: $r") + case _: MapType => Some("MapType is not supported by lance-spark writer") + case _: NullType => Some("NullType has no concrete element to write") + case other => Some(s"unrecognised DataType ${other.getClass.getSimpleName}") + } + /** * Generate a unique scratch path under `scratchDir`. Caller must clean up. The path * deliberately does NOT include the substring "lance" — the V2 catalog's path-identifier diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala index fbd1251f9..5d8f4517c 100644 --- a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala @@ -234,6 +234,99 @@ class LanceKnnImplicitsTest { assertEquals(NumLeft * 3, joined.length, "expected k=3 results per left row") } + /** + * Pin the within-query "same Lance dataset path" property: both the probe and the + * materialize stages of the staged plan must reference the same temp Lance URI. If + * the wiring ever drifts (e.g. helper produces URI A but probe pipeline gets URI B), + * the staged pipeline reads from the wrong dataset and produces silent wrong results. + * This is the structural pin that prevents that. + */ + @Test def testProbeAndMaterializeShareSameTempUri(): Unit = { + import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceProbeLogicalPlan} + val (leftDf, _, _) = buildLeft() + val (_, rightIds, rightVecs) = writeRight() + + val schema = new StructType(Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val inMemoryR = spark.createDataFrame(rows.toSeq.asJava, schema) + + val joined = leftDf.kNearestJoin( + right = inMemoryR, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + + // Walk the analyzed plan looking for the probe and materialize logical nodes; both + // must carry the same `datasetUri` in their stage configs. + val plan = joined.queryExecution.analyzed + val probeNodes = plan.collect { case p: LanceProbeLogicalPlan => p.stageConf.datasetUri } + val materializeNodes = plan.collect { case m: LanceMaterializeLogicalPlan => + m.stageConf.datasetUri + } + assertEquals( + 1, + probeNodes.size, + s"expected exactly one LanceProbeLogicalPlan; got: $probeNodes") + assertEquals( + 1, + materializeNodes.size, + s"expected exactly one LanceMaterializeLogicalPlan; got: $materializeNodes") + assertEquals( + probeNodes.head, + materializeNodes.head, + "probe and materialize must reference the SAME temp Lance dataset URI") + } + + /** + * Right side has a column type Lance can't write (MapType). The DataFrame API path + * is explicit — the user called `kNearestJoin` directly — so it must throw with a + * helpful message rather than silently fall through. (The SQL rule path in the 4.2 + * module makes the opposite choice — it falls through to Spark's brute-force rewrite + * because the user wrote a generic `APPROX NEAREST` query.) + */ + @Test def testKNearestJoinRejectsUnsupportedColumnType(): Unit = { + import org.apache.spark.sql.functions.{lit, map} + val (leftDf, _, _) = buildLeft() + val (_, rightIds, rightVecs) = writeRight() + val schema = new StructType(Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val withMap = spark.createDataFrame(rows.toSeq.asJava, schema) + .withColumn("attrs", map(lit("k"), lit("v"))) + + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + leftDf.kNearestJoin( + right = withMap, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid", "attrs")))) + val msg = ex.getMessage.toLowerCase + assertTrue( + msg.contains("attrs") || msg.contains("map"), + s"error should mention the offending column or type; got: ${ex.getMessage}") + } + /** * In-memory R (no underlying source — `createDataFrame(rows.asJava, schema)`). Same * temp-materialization path; just exercises the case where the rid synthesis and write diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala index 6512d4f98..0817c13a8 100644 --- a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala @@ -228,6 +228,77 @@ class LanceTempRTest { assertTrue(ex.getMessage.contains(LanceTempR.RidColumnName)) } + // -- schema validation --------------------------------------------------------------------- + + /** checkSupported is None for the typical projection (rid + vec + primitives + strings). */ + @Test def testCheckSupportedAcceptsCommonTypes(): Unit = { + val ok = new StructType(Array( + StructField("rid", LongType), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()), + StructField("title", StringType), + StructField("count", IntegerType), + StructField("when", TimestampType), + StructField("flag", BooleanType), + StructField("payload", BinaryType))) + assertEquals(None, LanceTempR.checkSupported(ok)) + } + + /** MapType is not Lance-writable — checkSupported flags it with a clear message. */ + @Test def testCheckSupportedRejectsMap(): Unit = { + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("attrs", MapType(StringType, StringType)))) + val res = LanceTempR.checkSupported(notOk) + assertTrue(res.isDefined, "MapType must be rejected") + assertTrue(res.get.contains("attrs"), s"reason should name the column: ${res.get}") + assertTrue(res.get.toLowerCase.contains("map"), s"reason should mention Map: ${res.get}") + } + + /** Map nested inside an Array also rejected (recursive check). */ + @Test def testCheckSupportedRejectsArrayOfMap(): Unit = { + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("nested", ArrayType(MapType(StringType, IntegerType))))) + assertTrue(LanceTempR.checkSupported(notOk).isDefined) + } + + /** Map nested inside a Struct also rejected. */ + @Test def testCheckSupportedRejectsStructWithMap(): Unit = { + val inner = new StructType(Array(StructField("ext", MapType(StringType, StringType)))) + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("metadata", inner))) + assertTrue(LanceTempR.checkSupported(notOk).isDefined) + } + + /** + * materialize() throws IllegalArgumentException when projection includes an unsupported + * type — caller (DataFrame API path) propagates this to the user. + */ + @Test def testMaterializeRejectsUnsupportedProjection(): Unit = { + import org.apache.spark.sql.functions.{lit, map} + // Construct a DataFrame with a map column. + val parquetUri = writeRandomParquet(2, Dim) + val withMap = spark.read.parquet(parquetUri) + .withColumn("attrs", map(lit("k1"), lit("v1"))) + + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + withMap, + vecCol = "vec", + projection = Seq("id", "attrs"), + scratchDir = scratch())) + assertTrue( + ex.getMessage.toLowerCase.contains("map") || ex.getMessage.contains("attrs"), + s"error should name the rejected column / type; got: ${ex.getMessage}") + } + // -- resolveScratchDir --------------------------------------------------------------------- /** Conf key set: returned as-is. */