diff --git a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala index c40c074bd..dc85ffd74 100644 --- a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala +++ b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala @@ -13,6 +13,7 @@ */ package org.lance.spark.knn.catalyst +import org.apache.spark.sql.{LanceKnnDatasetBridge, SparkSession} import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, VectorCosineSimilarity, VectorInnerProduct, VectorL2Distance} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, NearestByDirection, NearestByDistance, NearestBySimilarity} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, NearestByJoin, Project, SubqueryAlias} @@ -21,7 +22,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String -import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric} +import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, LanceTempR, Metric} import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan, ProbedLeftCodec} /** @@ -101,6 +102,22 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { /** Configuration key that gates the rule. Off by default to keep the rule opt-in for now. */ val EnabledConfKey: String = "spark.lance.knn.indexedNearestByJoin.enabled" + /** + * Configuration key that gates the per-query temp-Lance materialization for non-Lance + * right sides (sezruby/lance-spark#2). When enabled (and [[EnabledConfKey]] is also on), + * a `NearestByJoin` whose right side isn't a Lance scan triggers + * [[org.lance.spark.knn.internal.LanceTempR.materialize]] at rule-application time, then + * proceeds with the same staged-plan rewrite against the temp URI. + * + * Off by default. Two reasons to keep it opt-in: + * 1. The rule evaluates the right plan synchronously at analysis time — for large R + * that's a meaningful cost users should consciously accept. + * 2. Cluster mode requires `spark.lance.knn.tempR.dir` to be set (see + * [[LanceTempR.resolveScratchDir]]); a misconfigured cluster would fail with a + * clear error message that's still better surfaced behind an explicit opt-in. + */ + val TempRForSqlEnabledConfKey: String = "spark.lance.knn.tempRForSqlRule.enabled" + /** * IVF cluster count to visit per query. Higher = better recall, more compute. Default * (None) leaves Lance's index-default (typically 1). @@ -120,6 +137,7 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { } val nprobes = optInt(NprobesConfKey) val refineFactor = optInt(RefineFactorConfKey) + val tempRForSqlEnabled = conf.getConfString(TempRForSqlEnabledConfKey, "false").toBoolean plan.transformDown { case j @ NearestByJoin(left, right, joinType, true, k, rankingExpr, direction) => rewriteIfApplicable( @@ -131,7 +149,8 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { rankingExpr, direction, nprobes, - refineFactor).getOrElse(j) + refineFactor, + tempRForSqlEnabled).getOrElse(j) } } @@ -164,10 +183,20 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { rankingExpr: Expression, direction: NearestByDirection, nprobes: Option[Int], - refineFactor: Option[Int]): Option[LogicalPlan] = { + refineFactor: Option[Int], + tempRForSqlEnabled: Boolean): Option[LogicalPlan] = { for { - lance <- unwrapLanceScan(right) + // Recognize the ranking BEFORE attempting the right-side resolution so we don't + // pay a temp-Lance materialization for queries that are going to fall through + // anyway (wrong direction, mixed-side rank expression, etc.). (metric, leftVecAttr, rightVecCol) <- recognizeRanking(rankingExpr, direction, left, right) + lance <- unwrapLanceScan(right).orElse { + if (tempRForSqlEnabled) { + materializeNonLanceR(right, rightVecCol) + } else { + None + } + } } yield { val leftVecIdx = left.output.indexWhere(_.exprId == leftVecAttr.exprId) require(leftVecIdx >= 0, s"left vector attr not found in left.output: $leftVecAttr") @@ -422,6 +451,58 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] { } } + /** + * Synthesise a [[LanceScanInfo]] for a non-Lance right plan by materializing it to a + * temp Lance dataset (sezruby/lance-spark#2). The materialization runs synchronously at + * rule-application time. Returns `None` if anything goes wrong — caller falls through + * to Spark's brute-force rewrite. + * + * The synthesised `output` reuses the ORIGINAL right plan's attribute references + * (`right.output`) — preserving `ExprId`s — so the top-level `Project(j.output, ...)` + * the rule emits stays attribute-resolved. The materialize stage reads from the temp + * Lance dataset (which has those columns by name plus the synthetic `_rid`) and binds + * the read-back rows back to those original attribute references. + */ + private def materializeNonLanceR( + right: LogicalPlan, + rightVecCol: String): Option[LanceScanInfo] = { + try { + val spark = SparkSession.active + val rightDf = LanceKnnDatasetBridge.asDataFrame(spark, right) + // Pre-check schema BEFORE triggering any work — if Lance can't write any of the + // projected columns, fall through (return None) so Spark's brute-force rewrite + // handles the query. Same "refusal not partial pushdown" pattern the prefilter + // translator uses. + val projection: Seq[String] = right.output.map(_.name).filterNot(_ == rightVecCol) + val projectedSchema = StructType( + (rightVecCol +: projection).map(name => rightDf.schema(name))) + if (LanceTempR.checkSupported(projectedSchema).isDefined) { + return None + } + val scratchDir = LanceTempR.resolveScratchDir(spark) + val tempUri = LanceTempR.materialize( + right = rightDf, + vecCol = rightVecCol, + projection = projection, + scratchDir = scratchDir) + // Synthesise a LanceScanInfo: URI is the temp; version is None (per-query temp); + // output keeps the original right.output (preserving ExprIds for the top-level + // Project resolution); no prefilter (the temp already represents the FULLY + // evaluated right plan, so any source-side filters were applied during the write). + Some( + LanceScanInfo( + uri = tempUri, + version = None, + output = right.output, + prefilter = None)) + } catch { + case _: Throwable => + // Any failure (config missing in cluster mode, write fails, etc.) — fall through + // to Spark's brute-force rewrite. The user's query still runs, just slowly. + None + } + } + private def isLanceTable(table: Table): Boolean = { val cls = table.getClass.getName // Loose by design — the rule is opt-in via spark.lance.knn.indexedNearestByJoin.enabled, so diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala index a9bf8c509..15e68b8f5 100644 --- a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala +++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala @@ -290,6 +290,144 @@ class IndexedNearestByJoinRuleTest { assertSame(join, rewritten, "computed expression must refuse pushdown") } + // -- per-query temp-Lance path for non-Lance R (sezruby/lance-spark#2) -------------------- + + /** + * With `spark.lance.knn.tempRForSqlRule.enabled = true`, a non-Lance right side + * (here: a parquet-backed DataFrame) triggers per-query temp-Lance materialization + * inside the rule. The rule should rewrite to the same staged-plan tree it produces + * for Lance R, but with the probe pointed at the temp URI. + */ + @Test def testTempRForSqlRewritesNonLanceR(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + // Right side: an analyzed parquet plan, not a Lance scan. + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + // Same shape assertion as the Lance-R happy path: Project wrapping the materialize node. + assertTrue( + rewritten.isInstanceOf[Project] && + rewritten.asInstanceOf[Project].child.isInstanceOf[LanceMaterializeLogicalPlan], + s"expected rewrite to Project(LanceMaterialize(...)), got: $rewritten") + // The probe URI should point at a temp dir, not anything from the parquet relation. + val summary = expectRewritten(rewritten) + // No prefilter when the right side was materialized — the temp already represents + // the fully-evaluated plan. + assertEquals(None, summary.prefilter) + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + + /** + * With the temp-R conf enabled but the rule itself disabled, a non-Lance right side still + * falls through. The materialize-and-rewrite path is gated on BOTH the main enabled flag + * AND the temp-R-for-SQL flag — turning on only one isn't enough. + */ + @Test def testTempRForSqlRequiresMainEnabledFlag(): Unit = { + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) // main flag off + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "main enabled flag off → rule must not fire even with tempR enabled") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + } + } + + /** + * Right side has a Lance-unsupported column type (MapType). Even with the temp-R conf + * enabled, the rule must fall through to brute-force rather than throw — Catalyst rules + * shouldn't fail the user's query when a fallback exists. Pin: schema check refuses, + * rewrite returns None, original NearestByJoin propagates. + */ + @Test def testTempRForSqlFallsThroughOnUnsupportedSchema(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.set(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey, "true") + try { + val left = trivialPlan("lid", "lvec") + val rightWithMap = parquetLikePlanWithMap(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = rightWithMap.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + rightWithMap, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "right with MapType column must fall through (rule cannot materialize unsupported type)") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + + /** + * Without the temp-R conf, a non-Lance right side falls through even when the main rule is + * enabled. Pins the existing behavior so we don't accidentally change it when extending the + * rule. + */ + @Test def testNonLanceRWithoutTempRConfFallsThrough(): Unit = { + spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") + spark.conf.unset(IndexedNearestByJoinRule.TempRForSqlEnabledConfKey) + try { + val left = trivialPlan("lid", "lvec") + val parquetRight = parquetLikePlan(idCol = "rid", vecCol = "rvec") + val leftVec = left.output.find(_.name == "lvec").get + val rightVec = parquetRight.output.find(_.name == "rvec").get + val join = NearestByJoin( + left, + parquetRight, + Inner, + approx = true, + numResults = 5, + rankingExpression = VectorL2Distance(leftVec, rightVec), + direction = NearestByDistance) + val rewritten = IndexedNearestByJoinRule(join) + assertSame( + join, + rewritten, + "tempR conf off → non-Lance right must fall through to brute-force") + } finally { + spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey) + } + } + /** Filter wrapped in SubqueryAlias still pushes — order of unwrap shouldn't matter. */ @Test def testFilterUnderSubqueryAliasPushes(): Unit = { spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true") @@ -391,6 +529,47 @@ class IndexedNearestByJoinRuleTest { spark.createDataFrame(rows.asJava, schema).queryExecution.analyzed } + /** + * A parquet-backed analyzed plan for testing the per-query temp-R rewrite path. Materialization + * runs `df.write.format("lance").save(...)` against this plan, so the plan must execute + * (unlike `trivialPlan` which is just an analyzed `LocalRelation`). Writes a tiny parquet to + * tempDir on first call, returns its analyzed scan plan. + */ + private def parquetLikePlan(idCol: String, vecCol: String): LogicalPlan = { + val schema = new StructType(Array( + StructField(idCol, IntegerType, nullable = false), + StructField( + vecCol, + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()))) + val rows = (0 until 4).map(i => RowFactory.create(Integer.valueOf(i), Array.fill(8)(0.5f))) + val parquetPath = tempDir.resolve(s"parquet_for_rule_${System.nanoTime()}").toString + spark.createDataFrame(rows.asJava, schema).write.parquet(parquetPath) + spark.read.parquet(parquetPath).queryExecution.analyzed + } + + /** + * Like parquetLikePlan but adds a MapType column, which Lance's writer can't handle. Used + * to verify the rule's schema check refuses-and-falls-through on unsupported types. + */ + private def parquetLikePlanWithMap(idCol: String, vecCol: String): LogicalPlan = { + import org.apache.spark.sql.functions.{lit => sqlLit, map => mapFn} + val schema = new StructType(Array( + StructField(idCol, IntegerType, nullable = false), + StructField( + vecCol, + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()))) + val rows = (0 until 4).map(i => RowFactory.create(Integer.valueOf(i), Array.fill(8)(0.5f))) + val parquetPath = tempDir.resolve(s"parquet_with_map_${System.nanoTime()}").toString + val withMap = spark.createDataFrame(rows.asJava, schema) + .withColumn("attrs", mapFn(sqlLit("k"), sqlLit("v"))) + withMap.write.parquet(parquetPath) + spark.read.parquet(parquetPath).queryExecution.analyzed + } + /** * Build a `DataSourceV2Relation` whose `table.getClass.getName.contains("Lance")` so the * rule's duck-type check accepts it. We don't actually run any I/O. Includes a `category` diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala index 5b349dd85..e8862cb46 100644 --- a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala @@ -16,6 +16,7 @@ package org.lance.spark.knn import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.lance.spark.knn.internal.LanceTempR /** * Idiomatic DataFrame extension for the indexed nearest-K join. The Phase 2 SQL syntax @@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation * metric = "l2") * }}} * - * The right DataFrame MUST be a Lance scan — `spark.read.format("lance").load(uri)`. The - * extension extracts the underlying URI from the right-side analyzed plan; if it can't find a - * `LanceTable` it throws `IllegalArgumentException`. This is intentional: the indexed path - * uses Lance's Java API directly to open the dataset, so a non-Lance DataFrame cannot be - * substituted (there's no general "any DataFrame" indexed path). + * The right DataFrame can be either: + * + * - A Lance scan (`spark.read.format("lance").load(uri)`). The extension extracts the + * underlying URI from the right-side analyzed plan and the existing probe pipeline runs + * against it directly. + * - Any other DataFrame (parquet, delta, in-memory, the result of an arbitrary upstream + * plan). The extension materializes it to a temp Lance dataset first via + * [[org.lance.spark.knn.internal.LanceTempR.materialize]], then runs the same probe + * pipeline against the temp URI. See sezruby/lance-spark#2 for the design. + * + * For the temp-Lance path to work, the session must either be in local mode or have + * `spark.lance.knn.tempR.dir` set to a path every executor (and the driver) can read+write + * — typically a shared object store (`s3://...`, `abfss://...`, `file:///shared-mount/...`) + * or HDFS. Cluster runs missing this conf fail fast at materialization time with a clear + * error message. * * == Why an extension method, not a builder == * @@ -55,17 +66,23 @@ object LanceKnnImplicits { implicit class LanceKnnDataFrameOps(val df: DataFrame) extends AnyVal { /** - * Approximate top-K nearest-neighbor join over a Lance-backed right DataFrame. The right - * DataFrame must be a `spark.read.format("lance").load(uri)` (any plan that wraps a - * `LanceTable` — `Filter`, `SubqueryAlias`, `Project` are unwrapped). For a non-Lance - * right side or a derived plan that loses the URI, this method throws. + * Approximate top-K nearest-neighbor join. The right DataFrame can be: + * + * - A Lance scan (`spark.read.format("lance").load(uri)`) — runs against R directly. + * - Any other DataFrame — materialized to a temp Lance dataset first, then the + * existing probe pipeline runs against the temp URI. The temp is unique per call; + * it persists for the lifetime of the returned DataFrame's evaluation. (See + * `LanceTempLifecycle` for query-scoped cleanup.) * - * @param right Lance-backed right DataFrame + * @param right right DataFrame (Lance scan or any other source) * @param leftVecCol name of the vector column on `this` (left) * @param rightVecCol name of the vector column on `right` * @param k number of nearest neighbors per left row * @param metric distance / similarity metric: "l2" | "cosine" | "dot" - * @param rightProjection columns to materialize from `right`. `None` = all columns. + * @param rightProjection columns to materialize from `right`. `None` = all of R's + * non-vector columns (existing behavior on Lance R; carries + * everything into the temp on non-Lance R, which can be + * expensive for wide tables). * @param outerJoin left-outer mode: emit a left row even if zero neighbors found * @param scoreCol name of the appended score column (default `__score`) * @param overfetch multiplier on `k` during the probe before final trim @@ -95,7 +112,16 @@ object LanceKnnImplicits { ef: Option[Int] = None, probeParallelism: Int = 1, balanceFragments: Boolean = false): DataFrame = { - val (uri, version) = LanceKnnImplicits.extractLanceUri(right) + // Try the existing Lance-scan path first. Falls through to temp materialization for + // any non-Lance right (parquet, delta, in-memory, arbitrary subplan). + val (uri, version) = LanceKnnImplicits.extractLanceUri(right) match { + case Some(t) => t + case None => + val tempUri = LanceKnnImplicits.materializeNonLanceR(right, rightVecCol, rightProjection) + (tempUri, None) + } + // After temp materialization, the dataset has columns rid + vec + caller's projection. + // The probe pipeline reads everything from there; no further translation needed. IndexedNearestJoin( left = df, rightLanceUri = uri, @@ -120,9 +146,10 @@ object LanceKnnImplicits { /** * Walk a DataFrame's analyzed plan looking for a `LanceTable`-backed * `DataSourceV2Relation`. Skips through wrappers that don't change the underlying - * relation: `SubqueryAlias`, `View`, `Project`, `Filter`. Returns `(uri, optional version)` - * pulled from the relation's options. Throws `IllegalArgumentException` if no Lance scan - * is found. + * relation: `SubqueryAlias`, `View`, `Project`, `Filter`. Returns + * `Some((uri, optional version))` pulled from the relation's options when a Lance scan + * is found, or `None` otherwise — callers can fall through to temp materialization in + * that case rather than failing. * * Lance detection mirrors `IndexedNearestByJoinRule.isLanceTable` — * class-name match (`getClass.getName.contains("Lance")`) — to keep the user-facing @@ -132,20 +159,50 @@ object LanceKnnImplicits { * * Public for tests. */ - private[knn] def extractLanceUri(df: DataFrame): (String, Option[Long]) = { - val rel = findLanceRelation(df.queryExecution.analyzed).getOrElse { - throw new IllegalArgumentException( - "kNearestJoin requires the right DataFrame to be a Lance scan " + - "(spark.read.format(\"lance\").load(uri)). Plan was:\n" + - df.queryExecution.analyzed) + private[knn] def extractLanceUri(df: DataFrame): Option[(String, Option[Long])] = { + findLanceRelation(df.queryExecution.analyzed).flatMap { rel => + val opts = rel.options + val uri = Option(opts.get("path")).orElse(Option(opts.get("datasetUri"))) + uri.map { u => + val version = Option(opts.get("version")).map(_.toLong) + (u, version) + } + } + } + + /** + * Materialize a non-Lance right DataFrame to a temp Lance dataset and return its URI. + * Caller must clean up — see `LanceTempLifecycle` for the query-scoped sweeper. The + * scratch directory comes from [[LanceTempR.resolveScratchDir]] which reads + * `spark.lance.knn.tempR.dir` and falls back to local FS only in local mode. + * + * @param right non-Lance DataFrame to materialize + * @param rightVecCol vector column name (must exist on `right`) + * @param rightProjection columns to carry into the temp Lance dataset, in addition to + * the synthesised rid and the vector. `None` means "all columns + * of `right` other than the vector" — matches the existing + * Lance-R semantics where omitting `rightProjection` means + * "carry everything." For wide tables, callers should pass an + * explicit `Some(...)` to avoid copying unnecessary bytes. + */ + private[knn] def materializeNonLanceR( + right: DataFrame, + rightVecCol: String, + rightProjection: Option[Seq[String]]): String = { + val spark = right.sparkSession + val scratchDir = LanceTempR.resolveScratchDir(spark) + val projection: Seq[String] = rightProjection match { + case Some(cols) => cols.filterNot(_ == rightVecCol) + case None => + // Default to "carry every column of R other than the vector" — matches the + // semantics that omitting rightProjection on a Lance scan reads every column. + right.schema.fieldNames.toSeq.filterNot(_ == rightVecCol) } - val opts = rel.options - val uri = Option(opts.get("path")) - .orElse(Option(opts.get("datasetUri"))) - .getOrElse(throw new IllegalArgumentException( - "Lance relation found but no `path` / `datasetUri` option set; cannot extract URI")) - val version = Option(opts.get("version")).map(_.toLong) - (uri, version) + LanceTempR.materialize( + right, + vecCol = rightVecCol, + projection = projection, + scratchDir = scratchDir) } private def findLanceRelation(plan: LogicalPlan): Option[DataSourceV2Relation] = plan match { diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala new file mode 100644 index 000000000..9254b92d8 --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinTempRBenchmark.scala @@ -0,0 +1,490 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.benchmark + +import org.apache.spark.sql.{DataFrame, Row, RowFactory, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.lance.spark.knn.LanceKnnImplicits._ + +import java.nio.file.Files +import java.util.{Locale, Random} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +/** + * Validates the per-query temp-Lance design from + * [sezruby/lance-spark#2](https://github.com/sezruby/lance-spark/issues/2): when R lives in + * parquet (or any non-Lance source), the indexed `NearestByJoin` path can still apply by + * writing R to a temp Lance dataset before the probe. Three configs run on the SAME data + * in the SAME job (no cross-run noise): + * + * A: vanilla Spark crossJoin + L2 UDF + min_by_k on parquet R — the brute-force baseline + * a user would write today; matches what Spark 4.2's `RewriteNearestByJoin` lowers to. + * B: per-query temp Lance write + `kNearestJoin` against the temp URI — the per-query + * temp design under test. + * C: Lance-native R + `kNearestJoin` — same probe pipeline as B but R was pre-written to + * Lance once outside the timing loop. The "already-Lance" reference; (B - C) is the + * pure temp-write overhead per query. + * + * Three configs answer: + * 1. Does B beat A on parquet R? (B vs A speedup — is the per-query temp story faster + * than the brute-force baseline at all?) + * 2. How much overhead does the temp write add vs. the existing Lance-already path? + * (B - C, on the same hardware in the same run, no cross-run noise) + * 3. Sanity: probe-only median of B (excluding temp write) should match C closely; if + * not, something's off in the pipeline. + * + * == Local run == + * {{{ + * MAVEN_OPTS="-Xmx12g " \ + * ./mvnw -pl lance-spark-knn_2.12 -q exec:java -Pbenchmark \ + * -Dexec.mainClass="org.lance.spark.knn.benchmark.IndexedNearestJoinTempRBenchmark" + * }}} + * + * Local laptop runs are noisy enough that single-run point estimates are unreliable + * (multi-tenant CPU contention on macOS, single-machine no-real-parallelism). Headline + * numbers should come from a real distributed cluster — see `BENCHMARK_RESULTS.md` + * § "Variance / multi-tenant noise" for the discussion in the existing benchmark. + * + * == Cluster run == + * Build the fat JAR (`./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests`), + * upload, submit with `BENCH_CLUSTER_MODE=true` and `BENCH_DATA_PATH=`. + * + * Environment: + * BENCH_SCALES — comma-separated subset of {tiny, small, medium, fat}; + * default "tiny,small". + * tiny = |R|=100K, |L|=100, dim=128 + * small = |R|=1M, |L|=1000, dim=128 + * medium = |R|=100K, |L|=100, dim=1024 + * fat = |R|=1M, |L|=1000, dim=1024 + * BENCH_REPEATS — measured iterations per config; default 3. 1 warmup. + * BENCH_CLUSTER_MODE — `true` to skip `.master()` and bind-address configs; + * must be set when submitting to a real Spark cluster. + * BENCH_DATA_PATH — shared scratch URI (file://, s3://, hdfs://, etc.); + * required in cluster mode, optional locally. + * + * Reports timing breakdown for B: (temp_write_ms + probe_ms) so the temp-write cost is + * visible relative to the probe itself. + */ +object IndexedNearestJoinTempRBenchmark { + + private val K: Int = 10 + private val Seed: Long = 1337L + + private case class Scale(name: String, numR: Int, numL: Int, dim: Int) { + def vectorBytesR: Long = numR.toLong * dim.toLong * 4L + override def toString: String = + f"$name (|R|=$numR%,d, |L|=$numL%,d, dim=$dim, R-vec=${vectorBytesR / (1024.0 * 1024.0)}%.0f MB)" + } + + private val Scales: Map[String, Scale] = Seq( + Scale("tiny", numR = 100000, numL = 100, dim = 128), + Scale("small", numR = 1000000, numL = 1000, dim = 128), + Scale("medium", numR = 100000, numL = 100, dim = 1024), + Scale("fat", numR = 1000000, numL = 1000, dim = 1024)).map(s => s.name -> s).toMap + + private case class Result( + scale: String, + config: String, + tempWriteMs: Option[Long], + totalMs: Long, + runs: Seq[Long]) { + def probeMs: Option[Long] = tempWriteMs.map(w => totalMs - w) + } + + def main(args: Array[String]): Unit = { + val scaleNames = sys.env + .getOrElse("BENCH_SCALES", "tiny,small") + .toLowerCase(Locale.ROOT) + .split(",") + .map(_.trim) + .filter(_.nonEmpty) + val scales = scaleNames.map { n => + Scales.getOrElse( + n, + sys.error(s"unknown scale '$n'; valid: ${Scales.keys.toSeq.sorted.mkString(", ")}")) + } + val repeats = sys.env.get("BENCH_REPEATS").map(_.toInt).getOrElse(3) + val clusterMode = sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true")) + val dataRootOpt = sys.env.get("BENCH_DATA_PATH").orElse(sys.env.get("BENCH_DATA")) + + // Use user-supplied path in cluster mode, otherwise a local temp dir. + // Avoid "lance" in the path token; the V2 catalog's path-identifier parser tokenises + // around it on writes. Same workaround as LanceWriteBenchmark. + val dataRoot = + dataRootOpt.getOrElse(Files.createTempDirectory("knn-tempr-bench-").toString) + + val builder = SparkSession + .builder() + .appName("indexed-nearest-temp-r-benchmark") + .config("spark.sql.crossJoin.enabled", "true") + if (!clusterMode) { + builder + .master("local[*]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .config("spark.sql.shuffle.partitions", "16") + } + val spark = builder.getOrCreate() + spark.sparkContext.setLogLevel("WARN") + + println("=" * 76) + println("Per-query temp Lance benchmark — non-Lance R via temp write") + println("=" * 76) + val masterDesc = if (clusterMode) "cluster (BENCH_CLUSTER_MODE=true)" else "local[*]" + println(f"Spark master: $masterDesc (cores=${spark.sparkContext.defaultParallelism})") + println(f"Repeats: $repeats (median reported); 1 warmup") + println(f"Data root: $dataRoot") + println(f"K: $K") + println(f"Scales: ${scales.map(_.name).mkString(", ")}") + println() + + val results = scala.collection.mutable.ArrayBuffer.empty[Result] + try { + scales.foreach { scale => + println("-" * 76) + println(s"Scale: $scale") + println("-" * 76) + + // Build left in-memory and right as parquet on disk. This is the realistic shape: + // R isn't in Lance yet; we want to time the cost of materializing it. + val leftDf = buildLeft(spark, scale).cache() + leftDf.count() + val rightParquetUri = s"$dataRoot/${scale.name}_right.parquet" + writeRightParquet(spark, scale, rightParquetUri) + val rightDfParquet = spark.read.parquet(rightParquetUri) + + // Pre-materialize a Lance-native R once (outside the timing loop) for config C. + // This is the apples-to-apples reference: same data, already-Lance — what users get + // when they store R in Lance natively. The kNearestJoin call against this URI uses + // exactly the same probe pipeline as B; the only difference is no temp write. + val rightLanceUri = s"$dataRoot/${scale.name}_right_native.lance" + writeTempLance(rightDfParquet, rightLanceUri) + val rightDfLance = spark.read.format("lance").load(rightLanceUri) + + // Sanity-check on a 16-row subset that the per-query-temp path agrees with the + // crossJoin baseline. Bail early if not. + verifyOracle(spark, scale, leftDf, rightParquetUri, dataRoot) + + // ---- Config A: vanilla Spark crossJoin + L2 UDF + min_by_k baseline ---- + // Use fewer repeats at higher scales since each A run is O(|L| × |R|) pair + // evaluations and quickly becomes minutes per run. + val baselineRepeats = if (scale.numL.toLong * scale.numR > 100000000L) 1 else repeats + val resultA = + timeIt(scale.name, "A: Spark crossJoin + min_by_k (parquet R)", baselineRepeats) { + () => crossJoinMinByK(leftDf, rightDfParquet, K) + } + results += resultA + + // ---- Config B: per-query temp Lance write + existing kNearestJoin ---- + val resultB = timeWithBreakdown(scale.name, "B: temp Lance write + kNearestJoin", repeats) { + () => + val tempUri = s"$dataRoot/${scale.name}_temp_${System.nanoTime()}" + // Step 1: temp write (timed) + val twStart = System.nanoTime() + writeTempLance(rightDfParquet, tempUri) + val twMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - twStart) + // Step 2: KNN against temp Lance (timed via runFull at outer level) + val tempLanceDf = spark.read.format("lance").load(tempUri) + val joined = leftDf.kNearestJoin( + right = tempLanceDf, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + (twMs, joined) + } + results += resultB + + // ---- Config C: Lance-native R (already-Lance reference; no temp write) ---- + val resultC = timeIt(scale.name, "C: Lance-native R + kNearestJoin", repeats) { () => + leftDf.kNearestJoin( + right = rightDfLance, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + } + results += resultC + + // ---- Config D: kNearestJoin extension's built-in non-Lance handling ---- + // This is the actual code path users hit when they write + // `queries.kNearestJoin(parquetDf, ...)` after sezruby/lance-spark#2. + // The extension internally calls LanceKnnImplicits.materializeNonLanceR → + // LanceTempR.materialize → existing probe pipeline. Compared against B which + // does the same thing but with the materialization spelled out explicitly, + // (D - B) should be near zero — a sanity check on the extension's wiring. + val resultD = timeIt(scale.name, "D: kNearestJoin(parquetDf) — built-in temp", repeats) { + () => + leftDf.kNearestJoin( + right = rightDfParquet, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + } + results += resultD + + // ---- Plan-shape dump for the actual code path ---- + // Print the analyzed + executed plans for config D (the new built-in path) at + // the smallest scale only. Helps verify the temp-Lance materialization and the + // staged probe/merge/materialize pipeline appear as expected when users hit the + // sezruby/lance-spark#2 path through the public API. + if (scale.name == scales.head.name) { + println() + println("-- df.explain(true) for config D (kNearestJoin against parquet R) --") + val explainDf = leftDf.kNearestJoin( + right = rightDfParquet, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + explainDf.explain(extended = true) + } + + leftDf.unpersist() + println() + } + + println("=" * 76) + println("Summary") + println("=" * 76) + printSummary(results.toSeq) + } finally { + spark.stop() + } + } + + // -- workload setup ---------------------------------------------------------------------- + + private def buildLeft(spark: SparkSession, scale: Scale): DataFrame = { + val schema = leftSchema(scale.dim) + val rng = new Random(Seed ^ 1L) + val rows = (0 until scale.numL).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(rng, scale.dim)) + } + spark.createDataFrame(rows.asJava, schema) + } + + private def writeRightParquet(spark: SparkSession, scale: Scale, uri: String): Unit = { + val schema = rightSchema(scale.dim) + val parts = math.max(spark.sparkContext.defaultParallelism, 8) + val numR = scale.numR + val dim = scale.dim + val rdd = spark.sparkContext + .range(0L, numR.toLong, 1L, parts) + .mapPartitionsWithIndex { (idx, iter) => + val rng = new Random(0xCAFEBABEL ^ idx.toLong) + iter.map { i => + val v = new Array[Float](dim) + var k = 0 + while (k < dim) { v(k) = rng.nextFloat(); k += 1 } + RowFactory.create(Integer.valueOf(i.toInt), v): Row + } + } + spark.createDataFrame(rdd, schema).write.parquet(uri) + } + + private def writeTempLance(rightDf: DataFrame, tempUri: String): Unit = { + // Per-query temp: project rid + rvec only (the columns the existing kNearestJoin path + // needs). Carrying additional payload columns is a follow-up; this benchmark validates + // the minimal shape. + rightDf.select("rid", "rvec").write.format("lance").save(tempUri) + } + + private def leftSchema(dim: Int): StructType = new StructType( + Array( + StructField("lid", IntegerType, nullable = false), + StructField( + "lvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + + private def rightSchema(dim: Int): StructType = new StructType( + Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } + + // -- baseline ---------------------------------------------------------------------------- + + private def l2Udf = + udf((a: Seq[Float], b: Seq[Float]) => { + var s = 0.0f + var i = 0 + while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 } + s + }) + + /** + * Same shape as `IndexedNearestJoinBenchmark.crossProductMinByK`. The realistic + * baseline a user would write today on parquet R (and what Spark 4.2's + * RewriteNearestByJoin lowers to). + */ + private def crossJoinMinByK(left: DataFrame, right: DataFrame, k: Int): DataFrame = { + val l2 = l2Udf + val r = right.select("rid", "rvec") + val crossed = left.crossJoin(r).withColumn("__dist", l2(col("lvec"), col("rvec"))) + crossed.groupBy("lid") + .agg( + slice( + sort_array(collect_list(struct(col("__dist"), col("rid"))), asc = true), + 1, + k).as("__matches")) + .select(col("lid"), inline(col("__matches")).as(Seq("__dist", "rid"))) + .select("lid", "rid", "__dist") + } + + // -- timing harness ---------------------------------------------------------------------- + + private def runFull(df: DataFrame): Unit = + df.write.format("noop").mode("overwrite").save() + + private def timeIt(scale: String, config: String, repeats: Int)(f: () => DataFrame): Result = { + print(s" $config ... ") + System.out.flush() + runFull(f()) // warmup + val runs = (0 until repeats).map { _ => + val t0 = System.nanoTime() + runFull(f()) + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) + } + val sortedRuns = runs.sorted + val median = sortedRuns(sortedRuns.length / 2) + println(s"runs=${runs.mkString("[", ",", "]")} ms, median=$median ms") + Result(scale, config, tempWriteMs = None, totalMs = median, runs = runs) + } + + /** + * Like timeIt but the runnable returns (temp_write_ms, df). The full timing includes + * the temp write; we report it separately so the probe-only cost is visible. + */ + private def timeWithBreakdown(scale: String, config: String, repeats: Int)( + f: () => (Long, DataFrame)): Result = { + print(s" $config ... ") + System.out.flush() + val (_, warmDf) = f() + runFull(warmDf) // warmup + val records = (0 until repeats).map { _ => + val t0 = System.nanoTime() + val (twMs, df) = f() + runFull(df) + val total = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) + (twMs, total) + } + val sortedTotals = records.map(_._2).sorted + val medianTotal = sortedTotals(sortedTotals.length / 2) + val medianTw = records.map(_._1).sorted.apply(records.length / 2) + println( + f"runs(total)=${records.map(_._2).mkString("[", ",", "]")} ms, median=$medianTotal%d ms " + + f"(temp write median=$medianTw%d ms, probe median=${medianTotal - medianTw}%d ms)") + Result( + scale, + config, + tempWriteMs = Some(medianTw), + totalMs = medianTotal, + runs = records.map(_._2)) + } + + // -- oracle equivalence ------------------------------------------------------------------ + + private def verifyOracle( + spark: SparkSession, + scale: Scale, + leftDf: DataFrame, + rightParquetUri: String, + dataRoot: String): Unit = { + println(" Sanity: oracle check on 16-row left subset ...") + val left16 = leftDf.limit(16).cache() + left16.count() + val rightDfParquet = spark.read.parquet(rightParquetUri) + val tempUri = s"$dataRoot/${scale.name}_oracle_temp_${System.nanoTime()}" + writeTempLance(rightDfParquet, tempUri) + val tempLance = spark.read.format("lance").load(tempUri) + + val viaTemp = left16.kNearestJoin( + right = tempLance, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = K, + metric = "l2", + rightProjection = Some(Seq("rid")), + probeParallelism = 1) + val viaBaseline = crossJoinMinByK(left16, rightDfParquet, K) + + val tempByLid = viaTemp.collect().groupBy(_.getAs[Int]("lid")).map { + case (lid, rows) => lid -> rows.map(_.getAs[Int]("rid")).toSet + } + val baseByLid = viaBaseline.collect().groupBy(_.getAs[Int]("lid")).map { + case (lid, rows) => lid -> rows.map(_.getAs[Int]("rid")).toSet + } + val mismatches = tempByLid.toSeq.flatMap { + case (lid, ids) => + baseByLid.get(lid) match { + case Some(b) if b == ids => None + case Some(b) => Some(s"lid=$lid: temp=$ids baseline=$b") + case None => Some(s"lid=$lid: missing from baseline") + } + } + if (mismatches.nonEmpty) { + sys.error(s"Oracle mismatch:\n ${mismatches.mkString("\n ")}") + } + left16.unpersist() + println(f" ... oracle equivalence holds (${tempByLid.size} left rows × K=$K).") + } + + // -- reporting --------------------------------------------------------------------------- + + private def printSummary(results: Seq[Result]): Unit = { + val byScale = results.groupBy(_.scale) + println( + f"${"scale"}%-8s ${"config"}%-46s ${"med ms"}%8s ${"tw ms"}%8s ${"probe ms"}%9s ${"vs A"}%6s") + println("-" * 95) + val scaleOrder = Seq("tiny", "small", "medium", "fat").filter(byScale.contains) + scaleOrder.foreach { sc => + val rs = byScale(sc) + val baseline = rs.find(_.config.startsWith("A:")).map(_.totalMs).getOrElse(0L) + rs.foreach { r => + val twStr = r.tempWriteMs.map(_.toString).getOrElse("—") + val probeStr = r.probeMs.map(_.toString).getOrElse("—") + val speedup = + if (baseline > 0 && r.totalMs > 0) f"${baseline.toDouble / r.totalMs}%.1fx" else "—" + println( + f"${r.scale}%-8s ${r.config}%-46s ${r.totalMs}%8d $twStr%8s $probeStr%9s $speedup%6s") + } + println() + } + } +} diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala new file mode 100644 index 000000000..60bdaf19f --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempLifecycle.scala @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.hadoop.fs.{FileSystem, Path => HadoopPath} +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.SparkSession + +import java.io.{File, IOException} +import java.net.URI + +import scala.collection.mutable + +/** + * Query-scoped cleanup for the per-query temp Lance datasets created by + * [[LanceTempR.materialize]]. Without it, every `kNearestJoin` against a non-Lance R + * leaks a Lance dataset on whatever scratch storage `spark.lance.knn.tempR.dir` + * points at — local FS, S3, HDFS — until the JVM dies. + * + * == Design == + * + * Cleanup runs on `SparkListenerApplicationEnd` and at JVM-shutdown via a `Runtime` + * shutdown hook. We deliberately do NOT clean up on `onJobEnd`: a single + * `kNearestJoin` invocation can run multiple Spark jobs (the temp write itself, + * the probe stage, the merge shuffle, the materialize stage). Tying cleanup to + * `onJobEnd` would race the still-running probe and break correctness. + * + * `onApplicationEnd` covers the well-behaved case where the SparkSession stops + * cleanly. The shutdown hook covers crashes / hard kills. Either way the temp + * dirs registered up to that point are deleted on a best-effort basis (errors + * are logged but never re-thrown — cleanup must not break the user's job tear- + * down). + * + * == Why not call `Files.delete` directly == + * + * Temp URIs may live on object stores (s3://...), HDFS, ABFS — non-local FS. + * `java.nio.file.Files` only handles local FS. We dispatch through + * `org.apache.hadoop.fs.FileSystem.get(uri, conf)` which routes via the standard + * Spark/Hadoop FileSystem registry — same machinery `df.write.format("lance")` + * already uses to write the temp. + */ +private[knn] object LanceTempLifecycle { + + // Logger via Spark's slf4j via -- log directly with println to stderr if we can't import. + // Kept small to avoid pulling slf4j into the Lance-knn module's surface. + private def logWarn(msg: String): Unit = System.err.println(s"[LanceTempLifecycle] $msg") + private def logInfo(msg: String): Unit = {} // intentionally quiet at info level + + // Synchronised because Spark task threads, listener-bus threads, and the JVM shutdown + // thread can all touch this. Per-application instances live forever in a static map; + // cleanup is keyed on the application id so we can be sure not to drop a different + // app's temps when a SparkContext stops within the same JVM. + private val instances = new mutable.HashMap[String, ApplicationTempRegistry] + + // Single shutdown hook for the JVM, installed on first `register` call. Runs all + // application registries' cleanup paths. + private val shutdownHookInstalled = new java.util.concurrent.atomic.AtomicBoolean(false) + + /** + * Track `tempUri` for cleanup when `spark`'s application ends or the JVM exits, whichever + * comes first. Idempotent: if `tempUri` is already registered for this application, no-op. + */ + def register(spark: SparkSession, tempUri: String): Unit = synchronized { + val sc = spark.sparkContext + val appId = sc.applicationId + val registry = instances.getOrElseUpdate( + appId, { + val r = new ApplicationTempRegistry(sc, appId) + sc.addSparkListener(r) + r + }) + registry.add(tempUri) + ensureShutdownHook() + } + + /** Drop all registered temp URIs for `appId` and clean up the listener. Public for tests. */ + private[knn] def stopForTesting(appId: String): Unit = synchronized { + instances.remove(appId).foreach(_.cleanupAll()) + } + + /** Number of currently-registered temp URIs for an app — for assertions in tests. */ + private[knn] def registeredCount(appId: String): Int = synchronized { + instances.get(appId).map(_.size).getOrElse(0) + } + + private def ensureShutdownHook(): Unit = { + if (shutdownHookInstalled.compareAndSet(false, true)) { + Runtime.getRuntime.addShutdownHook(new Thread("lance-temp-r-cleanup") { + override def run(): Unit = LanceTempLifecycle.synchronized { + instances.values.foreach(_.cleanupAll()) + instances.clear() + } + }) + } + } + + /** + * Per-application registry of temp URIs. Subscribes to `SparkListenerApplicationEnd` + * so cleanup fires as soon as the SparkContext starts shutting down — before the + * scratch FS becomes unreachable in cluster-tear-down ordering. + */ + final private class ApplicationTempRegistry(sc: SparkContext, appId: String) + extends SparkListener { + + private val tempUris = new mutable.LinkedHashSet[String] + private val hadoopConf = sc.hadoopConfiguration + + def add(uri: String): Unit = LanceTempLifecycle.synchronized { + tempUris.add(uri) + } + + def size: Int = LanceTempLifecycle.synchronized(tempUris.size) + + override def onApplicationEnd(end: SparkListenerApplicationEnd): Unit = { + cleanupAll() + LanceTempLifecycle.synchronized { + instances.remove(appId) + } + } + + /** + * Best-effort delete of every registered temp URI. Errors are logged and swallowed — + * cleanup runs during shutdown / context-stop, where re-throwing would obscure the + * actual reason the application is ending. + */ + def cleanupAll(): Unit = LanceTempLifecycle.synchronized { + val snapshot = tempUris.toSeq + tempUris.clear() + snapshot.foreach(deleteSilently) + } + + private def deleteSilently(uri: String): Unit = { + try { + deleteUri(uri, hadoopConf) + logInfo(s"deleted temp Lance dataset: $uri") + } catch { + case e: Throwable => + logWarn(s"failed to delete temp Lance dataset '$uri': ${e.getClass.getSimpleName}: ${e.getMessage}") + } + } + } + + /** + * Delete `uri` recursively. Routes through the same Hadoop FileSystem registry that + * Spark uses for writes, so it handles local FS, S3, HDFS, ABFS, etc. uniformly. + */ + private[knn] def deleteUri( + uri: String, + hadoopConf: org.apache.hadoop.conf.Configuration): Unit = { + if (uri == null || uri.isEmpty) return + if (looksLikeBareLocalPath(uri)) { + // Hadoop FileSystem.get(uri) on a bare path can sometimes route through unexpected + // FS implementations on YARN clusters. For unambiguous local paths use java.nio. + val f = new File(uri) + if (f.exists()) deleteRecursive(f) + } else { + val javaUri: URI = new URI(uri) + val hPath = new HadoopPath(uri) + val fs = FileSystem.get(javaUri, hadoopConf) + if (fs.exists(hPath)) { + if (!fs.delete(hPath, /* recursive = */ true)) { + throw new IOException(s"FileSystem.delete returned false for $uri") + } + } + } + } + + private def looksLikeBareLocalPath(uri: String): Boolean = + !uri.contains("://") + + private def deleteRecursive(f: File): Unit = { + if (f.isDirectory) { + val children = f.listFiles() + if (children != null) children.foreach(deleteRecursive) + } + if (!f.delete() && f.exists()) { + throw new IOException(s"failed to delete $f") + } + } +} diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala new file mode 100644 index 000000000..6ff1a3fa4 --- /dev/null +++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceTempR.scala @@ -0,0 +1,264 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.functions.{col, monotonically_increasing_id} +import org.apache.spark.sql.types._ + +import java.nio.file.{Files, Path, Paths} +import java.util.UUID + +/** + * Per-query temp Lance materialization for the right side of an indexed + * `NearestByJoin` when R is not already a Lance scan (parquet, delta, in-memory, + * arbitrary subplan). Materialization is eager: when called, the helper drives + * `right.write.format("lance").save(tempUri)` and returns the URI. The caller + * then passes that URI to the existing Lance-native probe pipeline; the rest of + * the staged plan (probe / merge / materialize) is unchanged. + * + * Design: see [sezruby/lance-spark#2](https://github.com/sezruby/lance-spark/issues/2). + * + * == Why a synchronous helper rather than a Catalyst exec node == + * + * The probe pipeline reads R via `LanceProbeStage.Conf.datasetUri`, which is captured + * once at plan construction time. To plumb a temp-write through Catalyst as a separate + * node we'd have to restructure `LanceProbeExec` (currently `UnaryExecNode` with the + * left side as its child) into a multi-input shape that depends on both the left plan + * AND a sibling temp-write — substantial blast radius. The simpler form: do the temp + * write at plan-build time, hand the resulting URI to the probe like any other Lance + * URI. Same data path on the wire; the only loss is `df.explain()` doesn't display the + * temp write as its own Catalyst operator. A future PR can promote it to an exec node + * if `df.explain()` visibility becomes load-bearing for users. + * + * == Why `monotonically_increasing_id` is the right rid == + * + * Per-query temp doesn't need cross-execution rid stability — temp is built and consumed + * in the same job, then deleted. `monotonically_increasing_id` is unique within a single + * execution and zero-cost to compute, so it fits exactly. For a future cached / persistent + * sidecar (different feature, not this issue), `_metadata.row_index` would be the natural + * choice for parquet-backed R. + */ +private[knn] object LanceTempR { + + /** + * Column name for the synthetic rid that the probe stage's `_rowid IN (...)` lookup + * targets after the write. Stable identifier — referenced by callers when constructing + * `rightProjection`. + */ + val RidColumnName: String = "_rid" + + /** + * User-tunable Spark conf key pointing at the directory under which per-query temp + * Lance datasets are created. In cluster mode this MUST be set to a path every + * executor (and the driver) can read+write — typically a shared object-store URI + * (`s3://...`, `abfss://...`, `file:///shared-mount/...`) or HDFS. In local mode the + * helper falls back to a subdirectory of `spark.local.dir` when this key is unset. + */ + val ScratchDirConfKey: String = "spark.lance.knn.tempR.dir" + + /** + * Materialize `right` to a temp Lance dataset suitable for use as the right side of + * the existing indexed `NearestByJoin` pipeline. + * + * Steps: + * 1. Compute the projection schema: rid (synthesised) + vec + any caller-requested + * additional columns the parent plan references. + * 2. Build a projected DataFrame: `right.select(monotonically_increasing_id() as _rid, + * ...projection)`. + * 3. Write it to `/` as a Lance dataset. + * 4. Return the URI string. + * + * The caller is responsible for: + * - Deleting the temp directory when the query finishes (see `LanceTempLifecycle`). + * - Telling the probe pipeline to project / materialize `RidColumnName` and `vecCol` + * (and any payload columns it should carry). + * + * @param right The non-Lance DataFrame to materialize. + * @param vecCol Name of the FixedSizeList vector column on `right`. + * @param projection Columns from `right` to carry into temp Lance, in addition to the + * synthesised rid and the vector. Empty Seq = carry rid + vec only. + * Use this to thread any payload columns the parent plan references. + * @param scratchDir Directory under which to create the temp Lance dataset. Must + * be a path the executor processes can write to (local FS for + * single-node, shared object store for cluster mode). Use + * [[resolveScratchDir]] to pick this up from session config in + * typical callers. + * @return The URI of the materialized temp Lance dataset. + */ + def materialize( + right: DataFrame, + vecCol: String, + projection: Seq[String], + scratchDir: String): String = { + require(vecCol.nonEmpty, "vecCol must not be empty") + require(scratchDir.nonEmpty, "scratchDir must not be empty") + require( + right.schema.fieldNames.contains(vecCol), + s"right DataFrame schema does not contain vector column '$vecCol'; " + + s"have [${right.schema.fieldNames.mkString(", ")}]") + val unknownCols = projection.filterNot(right.schema.fieldNames.contains) + require( + unknownCols.isEmpty, + s"projection columns not present in right DataFrame schema: " + + s"[${unknownCols.mkString(", ")}]; have [${right.schema.fieldNames.mkString(", ")}]") + require( + !projection.contains(RidColumnName), + s"projection must not include the reserved rid column name '$RidColumnName' — " + + "the helper synthesises it. Pick a different name on `right` or rename before calling.") + // Reject unsupported types BEFORE triggering the write. We project before checking so + // we only inspect the columns actually being written (vec + caller-requested payload), + // not unrelated columns the user happened to leave on `right`. + val projectedFields: Seq[StructField] = + ((vecCol +: projection.filterNot(_ == vecCol)).distinct).map { name => + right.schema(name) + } + findUnsupportedField(StructType(projectedFields)).foreach { reason => + throw new IllegalArgumentException( + s"per-query temp Lance materialization rejected: $reason. " + + "Drop the offending column from `projection`, cast it to a supported type, or use a Lance-native right side.") + } + + val tempUri = mintTempUri(scratchDir) + val ridCol: Column = monotonically_increasing_id().as(RidColumnName) + val payloadCols: Seq[Column] = (vecCol +: projection.filterNot(_ == vecCol)).distinct.map(col) + val projected: DataFrame = right.select(ridCol +: payloadCols: _*) + + projected.write.format("lance").save(tempUri) + + // Register for query-scoped cleanup. Cleanup fires on SparkListenerApplicationEnd + // (when the SparkSession stops cleanly) and on JVM shutdown via a shutdown hook + // (covers crashes / hard kills). See LanceTempLifecycle. + LanceTempLifecycle.register(right.sparkSession, tempUri) + + tempUri + } + + /** + * Resolve a writable scratch directory from session configuration, with a clear error + * for cluster runs that haven't set [[ScratchDirConfKey]]. + * + * Resolution order: + * 1. `spark.lance.knn.tempR.dir` if set — used as-is. Caller is responsible for it + * being executor-readable. + * 2. Local-mode-only fallback: `spark.local.dir` first entry + `/lance-temp-r`. Only + * acceptable in `local[*]` mode where the driver and executors share a JVM and + * hence the local FS. + * + * In a cluster (`spark.master` does not start with `local`), missing + * [[ScratchDirConfKey]] throws [[IllegalStateException]] — failing here is much better + * than failing later with a `FileNotFoundException` on an executor that can't see the + * driver's local disk. + */ + def resolveScratchDir(spark: SparkSession): String = { + spark.conf.getOption(ScratchDirConfKey).filter(_.nonEmpty) match { + case Some(p) => p + case None => + val master = Option(spark.sparkContext.master).getOrElse("") + if (!master.startsWith("local")) { + throw new IllegalStateException( + s"$ScratchDirConfKey is not set and Spark master is '$master' — per-query " + + "temp Lance materialization needs a scratch path every executor can " + + s"read+write. Set $ScratchDirConfKey to a shared URI (s3://..., abfss://..., " + + "file:///shared-mount/..., hdfs://...).") + } + val localDir = Option(spark.sparkContext.getConf.get("spark.local.dir", null)) + .map(_.split(",").head.trim) + .filter(_.nonEmpty) + .getOrElse(System.getProperty("java.io.tmpdir")) + s"$localDir/lance-temp-r" + } + } + + /** + * Walk the columns the helper would write (rid + vec + caller-requested projection) + * and return the first column whose type Lance can't write, or `None` if everything is + * fine. Callers use this to decide their fallback behaviour — the SQL rule path + * silently returns the original `NearestByJoin` (Spark's brute-force handles it), the + * DataFrame API path throws `IllegalArgumentException`. + * + * Without this pre-check, an unsupported type would surface as an opaque write-time + * error inside `df.write.format("lance").save()` after we've already started shipping + * task closures — slow and confusing. + * + * @param schema the projected schema (rid + vec + payload cols) the helper would write. + * @return `Some(reason)` if any field is unsupported, `None` if every field is fine. + */ + def checkSupported(schema: StructType): Option[String] = + findUnsupportedField(schema) + + /** + * Conservative type-allowlist for what Lance can write via `df.write.format("lance")`. + * + * Allowed: + * - All numeric primitives (Byte/Short/Int/Long/Float/Double/Decimal) + * - Boolean, String, Binary + * - Date / Timestamp / TimestampNTZ + * - StructType — recursive check on each field + * - ArrayType — recursive check on element type + * + * Rejected: + * - MapType — Lance's columnar layout doesn't support arbitrary string-keyed maps + * - CalendarIntervalType — no Arrow correspondence + * - UserDefinedType — opaque blobs that Lance can't know about + * - NullType — there's no element type to write + */ + private def findUnsupportedField(schema: StructType): Option[String] = { + schema.fields.iterator.flatMap { f => + findUnsupportedType(f.dataType).map(reason => + s"column '${f.name}' has type ${f.dataType.sql} which is not Lance-writable: $reason") + }.toSeq.headOption + } + + private def findUnsupportedType(dt: DataType): Option[String] = dt match { + case _: NumericType | BooleanType | StringType | BinaryType => None + case DateType | TimestampType => None + case s: StructType => + // Recursive check on each field. + s.fields.iterator.flatMap { f => + findUnsupportedType(f.dataType).map(r => s"nested field '${f.name}' of struct: $r") + }.toSeq.headOption + case a: ArrayType => + findUnsupportedType(a.elementType).map(r => s"array element type unsupported: $r") + case _: MapType => Some("MapType is not supported by lance-spark writer") + case _: NullType => Some("NullType has no concrete element to write") + case other => Some(s"unrecognised DataType ${other.getClass.getSimpleName}") + } + + /** + * Generate a unique scratch path under `scratchDir`. Caller must clean up. The path + * deliberately does NOT include the substring "lance" — the V2 catalog's path-identifier + * parser tokenises around it on writes (same workaround documented in + * `LanceWriteBenchmark`). + */ + def mintTempUri(scratchDir: String): String = { + val token = "tempr-" + UUID.randomUUID().toString + // For a local scratch dir, ensure the parent exists. Object-store URIs (s3://, etc.) + // are passed through as-is — the lance writer creates the path. + if (isLocalPath(scratchDir)) { + val parent: Path = Paths.get(stripFileScheme(scratchDir)) + Files.createDirectories(parent) + parent.resolve(token).toString + } else { + val sep = if (scratchDir.endsWith("/")) "" else "/" + s"$scratchDir$sep$token" + } + } + + private def isLocalPath(uri: String): Boolean = + uri.startsWith("/") || uri.startsWith("file://") || !uri.contains("://") + + private def stripFileScheme(uri: String): String = + if (uri.startsWith("file://")) uri.substring("file://".length) else uri +} diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala index 4f66f0840..5d8f4517c 100644 --- a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala @@ -25,15 +25,14 @@ import java.util.Random import scala.collection.JavaConverters._ /** - * End-to-end tests for the `df.kNearestJoin(rightDf, ...)` extension. Three things to - * verify: + * End-to-end tests for the `df.kNearestJoin(rightDf, ...)` extension. Verifies: * * 1. The extension returns the same rows as `IndexedNearestJoin.apply(uri, ...)` — the - * wrapper just changes the call site, not the semantics. + * wrapper just changes the call site, not the semantics, when R is a Lance scan. * 2. URI extraction handles a plain Lance `spark.read.load` — the common case. - * 3. URI extraction throws cleanly when the right DataFrame isn't backed by a Lance scan - * (e.g. created from in-memory rows). Bad input must fail fast with a helpful message, - * not surface as a confusing runtime error inside the probe. + * 3. **Non-Lance R is supported** via per-query temp Lance materialization (sezruby/lance-spark#2): + * parquet, in-memory, alias-wrapped, and subplan-backed sources all produce the same + * top-K as the equivalent Lance-native R. * * The Phase 0 oracle test in `LanceProbeValidationTest` covers correctness of the underlying * probe; we can keep these tests light and not re-validate that. @@ -119,89 +118,251 @@ class LanceKnnImplicitsTest { } /** - * A DataFrame backed by Parquet (or any non-Lance format) must also fail the Lance-only - * guard — the API contract is `format("lance").load(...)` specifically, not "any DataFrame - * Spark can read." Catches the case where a user wires in the wrong reader by mistake. + * Parquet R: kNearestJoin transparently materializes a temp Lance dataset and returns + * the same top-K row IDs as the equivalent Lance-native R run. Validates the oracle + * equivalence end-to-end. + * + * Per-issue #2 design: non-Lance R is materialized once via [[LanceTempR.materialize]], + * then the existing probe pipeline runs on the temp URI. */ - @Test def testParquetRightThrowsClearError(): Unit = { + @Test def testKNearestJoinAgainstParquetRMatchesLanceR(): Unit = { val (leftDf, _, _) = buildLeft() + val (rightLanceUri, rightIds, rightVecs) = writeRight() + + // Same R data as parquet, with rid + rvec only (matches the projection rightDf-Lance uses) + val parquetPath = tempDir.resolve(s"docs_${System.nanoTime()}.parquet").toString + rightIds.zip(rightVecs).map { case (rid, vec) => (rid, vec) } // sanity unused val parquetSchema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val rows = Seq( - RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f)), - RowFactory.create(Integer.valueOf(2), Array.fill(Dim)(0.5f))) - val parquetPath = tempDir.resolve(s"docs_${System.nanoTime()}.parquet").toString - spark.createDataFrame(rows.asJava, parquetSchema).write.parquet(parquetPath) + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + spark.createDataFrame(rows.toSeq.asJava, parquetSchema).write.parquet(parquetPath) val parquetDf = spark.read.parquet(parquetPath) - val ex = assertThrows( - classOf[IllegalArgumentException], - () => - leftDf.kNearestJoin( - right = parquetDf, - leftVecCol = "lvec", - rightVecCol = "rvec", - k = 3, - metric = "l2")) + val viaParquet = leftDf + .kNearestJoin( + right = parquetDf, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 5, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + val viaLance = IndexedNearestJoin( + left = leftDf, + rightLanceUri = rightLanceUri, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 5, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + val byLid = (rs: Array[org.apache.spark.sql.Row]) => + rs.groupBy(_.getAs[Int]("lid")).map { case (lid, group) => + lid -> group.map(_.getAs[Int]("rid")).toSet + } + assertEquals( + byLid(viaLance), + byLid(viaParquet), + "parquet R via temp materialization must produce the same top-K rid set as Lance R") + } + + /** + * Subplan-backed R: parquet → Filter → Project. The kNearestJoin extension only sees a + * DataFrame; the temp-Lance materialization driver-evaluates whatever subplan is + * underneath. Tests that this load-bearing case (issue #2 primary motivation) works. + */ + @Test def testKNearestJoinAgainstSubplanR(): Unit = { + val (leftDf, _, _) = buildLeft() + val (_, rightIds, rightVecs) = writeRight() + + // Wider source so we have something to Filter / Project away. Add a `kept` boolean + // column; the subplan will keep only rows where kept=true. + val wideSchema = new StructType(Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()), + StructField("kept", BooleanType, nullable = false), + StructField("payload", StringType, nullable = false))) + val rows = rightIds.zip(rightVecs).zipWithIndex.map { + case ((rid, vec), idx) => + RowFactory.create( + Integer.valueOf(rid), + vec, + java.lang.Boolean.valueOf(idx % 2 == 0), + s"p$idx") + } + val widePath = tempDir.resolve(s"wide_${System.nanoTime()}.parquet").toString + spark.createDataFrame(rows.toSeq.asJava, wideSchema).write.parquet(widePath) + + import org.apache.spark.sql.functions.col + val subplan = spark.read.parquet(widePath) + .filter(col("kept") === true) + .select("rid", "rvec") + + val expectedKeptIds = rightIds.zipWithIndex.collect { case (rid, idx) if idx % 2 == 0 => rid } + .toSet + + val joined = leftDf + .kNearestJoin( + right = subplan, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + + // Every returned rid must come from the kept subset — proves the subplan was actually + // evaluated before materialization (not just a reference to the underlying parquet). + val joinedRids = joined.map(_.getAs[Int]("rid")).toSet + val leakedRids = joinedRids -- expectedKeptIds assertTrue( - ex.getMessage.contains("Lance scan"), - s"expected error message to mention Lance scan for parquet input; got: ${ex.getMessage}") + leakedRids.isEmpty, + s"top-K must be drawn only from the kept subset; got leaks: $leakedRids") + assertEquals(NumLeft * 3, joined.length, "expected k=3 results per left row") } /** - * Non-Lance DataFrame wrapped in a `SubqueryAlias` (via `as("d")`) must still fail. The - * URI extractor walks `SubqueryAlias` to find the underlying relation; if the underlying - * is not Lance, alias unwrapping must NOT silently accept it. + * Pin the within-query "same Lance dataset path" property: both the probe and the + * materialize stages of the staged plan must reference the same temp Lance URI. If + * the wiring ever drifts (e.g. helper produces URI A but probe pipeline gets URI B), + * the staged pipeline reads from the wrong dataset and produces silent wrong results. + * This is the structural pin that prevents that. */ - @Test def testNonLanceUnderAliasThrowsClearError(): Unit = { + @Test def testProbeAndMaterializeShareSameTempUri(): Unit = { + import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceProbeLogicalPlan} val (leftDf, _, _) = buildLeft() - val rows = Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f))) + val (_, rightIds, rightVecs) = writeRight() + val schema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val notLance = spark.createDataFrame(rows.asJava, schema).as("d") + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val inMemoryR = spark.createDataFrame(rows.toSeq.asJava, schema) - val ex = assertThrows( - classOf[IllegalArgumentException], - () => - leftDf.kNearestJoin( - right = notLance, - leftVecCol = "lvec", - rightVecCol = "rvec", - k = 3, - metric = "l2")) - assertTrue( - ex.getMessage.contains("Lance scan"), - s"alias-wrapped non-Lance must still fail; got: ${ex.getMessage}") + val joined = leftDf.kNearestJoin( + right = inMemoryR, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + + // Walk the analyzed plan looking for the probe and materialize logical nodes; both + // must carry the same `datasetUri` in their stage configs. + val plan = joined.queryExecution.analyzed + val probeNodes = plan.collect { case p: LanceProbeLogicalPlan => p.stageConf.datasetUri } + val materializeNodes = plan.collect { case m: LanceMaterializeLogicalPlan => + m.stageConf.datasetUri + } + assertEquals( + 1, + probeNodes.size, + s"expected exactly one LanceProbeLogicalPlan; got: $probeNodes") + assertEquals( + 1, + materializeNodes.size, + s"expected exactly one LanceMaterializeLogicalPlan; got: $materializeNodes") + assertEquals( + probeNodes.head, + materializeNodes.head, + "probe and materialize must reference the SAME temp Lance dataset URI") } /** - * A DataFrame built from in-memory rows is NOT a Lance scan — the extension should throw - * an `IllegalArgumentException` with a message naming the constraint, so the user knows - * to hand a real Lance DataFrame instead. + * Right side has a column type Lance can't write (MapType). The DataFrame API path + * is explicit — the user called `kNearestJoin` directly — so it must throw with a + * helpful message rather than silently fall through. (The SQL rule path in the 4.2 + * module makes the opposite choice — it falls through to Spark's brute-force rewrite + * because the user wrote a generic `APPROX NEAREST` query.) */ - @Test def testNonLanceRightThrowsClearError(): Unit = { + @Test def testKNearestJoinRejectsUnsupportedColumnType(): Unit = { + import org.apache.spark.sql.functions.{lit, map} val (leftDf, _, _) = buildLeft() - val ridSchema = new StructType(Array( + val (_, rightIds, rightVecs) = writeRight() + val schema = new StructType(Array( StructField("rid", IntegerType, nullable = false), - StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false))) - val notLance = spark.createDataFrame( - Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f))).asJava, - ridSchema) + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val withMap = spark.createDataFrame(rows.toSeq.asJava, schema) + .withColumn("attrs", map(lit("k"), lit("v"))) val ex = assertThrows( classOf[IllegalArgumentException], () => leftDf.kNearestJoin( - right = notLance, + right = withMap, leftVecCol = "lvec", rightVecCol = "rvec", k = 3, - metric = "l2")) + metric = "l2", + rightProjection = Some(Seq("rid", "attrs")))) + val msg = ex.getMessage.toLowerCase assertTrue( - ex.getMessage.contains("Lance scan"), - s"expected error message to mention Lance scan; got: ${ex.getMessage}") + msg.contains("attrs") || msg.contains("map"), + s"error should mention the offending column or type; got: ${ex.getMessage}") + } + + /** + * In-memory R (no underlying source — `createDataFrame(rows.asJava, schema)`). Same + * temp-materialization path; just exercises the case where the rid synthesis and write + * have no parquet/delta to come from. + */ + @Test def testKNearestJoinAgainstInMemoryR(): Unit = { + val (leftDf, _, _) = buildLeft() + val (_, rightIds, rightVecs) = writeRight() + + val schema = new StructType(Array( + StructField("rid", IntegerType, nullable = false), + StructField( + "rvec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows = rightIds.zip(rightVecs).map { case (rid, vec) => + RowFactory.create(Integer.valueOf(rid), vec) + } + val inMemoryR = spark.createDataFrame(rows.toSeq.asJava, schema) + + // Verify even alias-wrapped works (SubqueryAlias unwrap → not a Lance scan → temp). + val aliased = inMemoryR.as("docs") + val joined = leftDf.kNearestJoin( + right = aliased, + leftVecCol = "lvec", + rightVecCol = "rvec", + k = 3, + metric = "l2", + rightProjection = Some(Seq("rid"))) + .collect() + assertEquals(NumLeft * 3, joined.length, "expected k=3 results per left row") + // All returned rids must come from the actual right side — sanity check. + val joinedRids = joined.map(_.getAs[Int]("rid")).toSet + val leaks = joinedRids -- rightIds.toSet + assertTrue(leaks.isEmpty, s"rids should be drawn from the input set; leaks: $leaks") } // -- helpers ------------------------------------------------------------------------------ diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala new file mode 100644 index 000000000..89b7790e5 --- /dev/null +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempLifecycleTest.scala @@ -0,0 +1,185 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Row, RowFactory, SparkSession} +import org.apache.spark.sql.types._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.io.TempDir + +import java.io.File +import java.nio.file.{Files, Path, Paths} +import java.util.Random + +import scala.collection.JavaConverters._ + +/** + * Tests for [[LanceTempLifecycle]]: + * + * - register adds the URI to the per-app registry, count reflects it + * - stopping the SparkSession (which fires SparkListenerApplicationEnd) deletes the + * registered temp dirs + * - stopForTesting (a back-door we expose explicitly so tests don't have to actually + * stop the session — that breaks subsequent BeforeEach setup in the same suite) + * also deletes + * - deleteUri handles a bare local path + * - registering the same URI twice in one app is idempotent + * + * Concurrent / multi-app cases are exercised by the existence of `appId`-keyed maps in + * the lifecycle code itself — testing genuine cross-app isolation in a JUnit suite would + * require multi-process orchestration that's not worth the test infra cost. The unit + * tests here cover the listener-driven path and the registry mechanics. + */ +class LanceTempLifecycleTest { + + @TempDir var tempDir: Path = _ + private var spark: SparkSession = _ + + private val Dim: Int = 4 + + @BeforeEach def setup(): Unit = { + spark = SparkSession.builder() + .appName("lance-temp-lifecycle-test") + .master("local[2]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .getOrCreate() + } + + @AfterEach def teardown(): Unit = { + if (spark != null) { + // Clean up any test residue regardless of pass/fail. stopForTesting also handles + // any URIs leftover (e.g. when a test registered without going through .stop()). + LanceTempLifecycle.stopForTesting(spark.sparkContext.applicationId) + spark.stop() + } + } + + /** Registering and then deleting via the test back-door removes the URI from disk. */ + @Test def testRegisterAndExplicitCleanup(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + + assertTrue(new File(tempUri).exists(), "precondition: temp Lance dataset exists") + assertEquals(0, LanceTempLifecycle.registeredCount(appId), "no registrations yet") + + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId)) + + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(tempUri).exists(), "cleanup must delete the temp dir") + assertEquals( + 0, + LanceTempLifecycle.registeredCount(appId), + "registry must be empty after cleanup") + } + + /** Multiple temp URIs from the same app are all cleaned up. */ + @Test def testMultipleRegistrationsAllCleanedUp(): Unit = { + val a = writeTempLance() + val b = writeTempLance() + val c = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, a) + LanceTempLifecycle.register(spark, b) + LanceTempLifecycle.register(spark, c) + assertEquals(3, LanceTempLifecycle.registeredCount(appId)) + + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(a).exists()) + assertFalse(new File(b).exists()) + assertFalse(new File(c).exists()) + } + + /** + * Re-registering the same URI is a no-op. Important because LanceTempR.materialize + * could be called repeatedly with overlapping temp URIs in retry scenarios. + */ + @Test def testIdempotentRegistration(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, tempUri) + LanceTempLifecycle.register(spark, tempUri) + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId), "duplicate registers are deduped") + LanceTempLifecycle.stopForTesting(appId) + assertFalse(new File(tempUri).exists()) + } + + /** + * Stopping the SparkSession fires SparkListenerApplicationEnd and triggers cleanup + * via the listener path — the production cleanup trigger. We tear down `spark` + * inside this test, so override @AfterEach behavior by setting `spark = null`. + */ + @Test def testApplicationEndTriggersCleanup(): Unit = { + val tempUri = writeTempLance() + val appId = spark.sparkContext.applicationId + LanceTempLifecycle.register(spark, tempUri) + assertEquals(1, LanceTempLifecycle.registeredCount(appId)) + + spark.stop() + spark = null // prevent @AfterEach from calling stop() again + + // Listener fires on the listener bus thread; give it a moment to drain. + val deadline = System.currentTimeMillis() + 5000 + while (new File(tempUri).exists() && System.currentTimeMillis() < deadline) { + Thread.sleep(50) + } + assertFalse( + new File(tempUri).exists(), + "SparkListenerApplicationEnd path must delete the registered temp dir within 5s") + assertEquals(0, LanceTempLifecycle.registeredCount(appId)) + } + + /** deleteUri handles a non-existent path gracefully (no exception). */ + @Test def testDeleteUriNonExistentNoOps(): Unit = { + val ghost = tempDir.resolve("never_existed_" + System.nanoTime()).toString + LanceTempLifecycle.deleteUri(ghost, spark.sparkContext.hadoopConfiguration) + // No assertion — pass means no exception. + } + + /** deleteUri null/empty input is a no-op. */ + @Test def testDeleteUriNullOrEmpty(): Unit = { + LanceTempLifecycle.deleteUri(null, spark.sparkContext.hadoopConfiguration) + LanceTempLifecycle.deleteUri("", spark.sparkContext.hadoopConfiguration) + } + + // -- helpers -------------------------------------------------------------------------------- + + /** Write a tiny Lance dataset under tempDir and return its URI. */ + private def writeTempLance(): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()))) + val rows: Seq[Row] = (0 until 4).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(new Random(i.toLong), Dim)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val target = tempDir.resolve("temp_" + System.nanoTime()) + df.write.format("lance").save(target.toString) + target.toString + } + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } +} diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala new file mode 100644 index 000000000..0817c13a8 --- /dev/null +++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceTempRTest.scala @@ -0,0 +1,401 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.knn.internal + +import org.apache.spark.sql.{Row, RowFactory, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.io.TempDir + +import java.nio.file.{Files, Path} +import java.util.Random + +import scala.collection.JavaConverters._ + +/** + * Behavioural tests for the per-query temp-Lance helper [[LanceTempR]]. Validates the + * properties the design relies on: + * + * - Round-trip: rows materialized to temp Lance read back with the same row count and + * payload as the source DataFrame. + * - Synthesised rid is unique within a single materialization. + * - FixedSizeList vector columns survive the write+read. + * - Caller-requested projection columns are present in the temp; non-requested are + * dropped (column pruning). + * - Subplan-backed sources (Filter / Project chains over parquet) work the same as + * a flat parquet read — the helper only sees a `DataFrame`, not its source. + * - Empty source produces an empty (but readable) temp. + * - Validation: missing vec col, unknown projection col, reserved rid name → fail fast. + */ +class LanceTempRTest { + + @TempDir var tempDir: Path = _ + private var spark: SparkSession = _ + + private val Dim: Int = 8 + private val NumRows: Int = 32 + private val Seed: Long = 11L + + @BeforeEach def setup(): Unit = { + spark = SparkSession.builder() + .appName("lance-temp-r-test") + .master("local[2]") + .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.driver.host", "127.0.0.1") + .getOrCreate() + } + + @AfterEach def teardown(): Unit = { + if (spark != null) spark.stop() + } + + // -- core round-trip ------------------------------------------------------------------------ + + /** Parquet-backed source: write via the helper, read back, verify row count + uniqueness of rid. */ + @Test def testParquetRoundTripRowCountAndRidUnique(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val readBack = spark.read.format("lance").load(tempUri) + val rows = readBack.collect() + assertEquals(NumRows, rows.length, "row count round-trips") + + val rids = readBack.select(col(LanceTempR.RidColumnName)).collect().map(_.getLong(0)) + assertEquals(NumRows, rids.distinct.length, "rid column is unique within one materialization") + } + + /** Vec column survives unchanged: per-row equality with the source. */ + @Test def testVectorColumnPreserved(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id"), + scratchDir = scratch()) + + val srcRows = rightDf.orderBy("id").collect() + val tempRows = spark.read.format("lance").load(tempUri).orderBy("id").collect() + assertEquals(srcRows.length, tempRows.length) + srcRows.zip(tempRows).foreach { case (s, t) => + val sv = s.getAs[Seq[Float]]("vec").toArray + val tv = t.getAs[Seq[Float]]("vec").toArray + assertArrayEquals(sv, tv, 1e-6f, s"vector mismatch for id=${s.getInt(s.fieldIndex("id"))}") + } + } + + // -- projection (column pruning at write time) --------------------------------------------- + + /** When projection is specified, only requested cols (plus rid + vec) appear in temp. */ + @Test def testProjectionTrimsExtraColumns(): Unit = { + val parquetUri = writeWideParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + // Source has id, vec, label, payload, untouched. Project only id + label. + val tempUri = LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id", "label"), + scratchDir = scratch()) + + val tempSchema = spark.read.format("lance").load(tempUri).schema.fieldNames.toSet + assertEquals( + Set(LanceTempR.RidColumnName, "vec", "id", "label"), + tempSchema, + "temp Lance schema should be exactly rid + vec + projection cols") + } + + /** Empty projection still gets rid + vec. */ + @Test def testEmptyProjectionGivesRidPlusVec(): Unit = { + val parquetUri = writeRandomParquet(NumRows, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val tempSchema = spark.read.format("lance").load(tempUri).schema.fieldNames.toSet + assertEquals(Set(LanceTempR.RidColumnName, "vec"), tempSchema) + } + + // -- subplan source (the load-bearing case) ------------------------------------------------- + + /** + * Source is a subplan: parquet → Filter → Project. Helper handles it the same as flat parquet, + * because it only consumes a DataFrame, not knowledge of the source. + */ + @Test def testSubplanSourceFilterPlusProject(): Unit = { + val parquetUri = writeWideParquet(NumRows, Dim) + val raw = spark.read.parquet(parquetUri) + // Keep only label = "even" rows (ids 0,2,4,...) and drop the payload column. + val subplanRight = raw.filter(col("label") === "even").select("id", "vec") + + val expectedCount = subplanRight.count() + assertTrue(expectedCount > 0, "test setup: subplan should have rows") + + val tempUri = LanceTempR.materialize( + subplanRight, + vecCol = "vec", + projection = Seq("id"), + scratchDir = scratch()) + val readBack = spark.read.format("lance").load(tempUri) + assertEquals( + expectedCount, + readBack.count(), + "row count of temp Lance equals row count of subplan-evaluated source") + } + + // -- empty input ---------------------------------------------------------------------------- + + /** Empty source DataFrame: temp dataset is created and reads back as 0 rows. */ + @Test def testEmptyDataFrame(): Unit = { + val parquetUri = writeRandomParquet(0, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = scratch()) + + val readBack = spark.read.format("lance").load(tempUri) + assertEquals(0L, readBack.count(), "empty source produces empty Lance dataset") + } + + // -- validation ----------------------------------------------------------------------------- + + @Test def testRejectsMissingVecColumn(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "no_such_col", + projection = Seq.empty, + scratchDir = scratch())) + assertTrue(ex.getMessage.contains("no_such_col")) + } + + @Test def testRejectsUnknownProjectionColumn(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq("id", "ghost"), + scratchDir = scratch())) + assertTrue(ex.getMessage.contains("ghost")) + } + + @Test def testRejectsReservedRidName(): Unit = { + val parquetUri = writeRandomParquet(2, Dim) + val rightDf = spark.read.parquet(parquetUri) + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq(LanceTempR.RidColumnName), + scratchDir = scratch())) + assertTrue(ex.getMessage.contains(LanceTempR.RidColumnName)) + } + + // -- schema validation --------------------------------------------------------------------- + + /** checkSupported is None for the typical projection (rid + vec + primitives + strings). */ + @Test def testCheckSupportedAcceptsCommonTypes(): Unit = { + val ok = new StructType(Array( + StructField("rid", LongType), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()), + StructField("title", StringType), + StructField("count", IntegerType), + StructField("when", TimestampType), + StructField("flag", BooleanType), + StructField("payload", BinaryType))) + assertEquals(None, LanceTempR.checkSupported(ok)) + } + + /** MapType is not Lance-writable — checkSupported flags it with a clear message. */ + @Test def testCheckSupportedRejectsMap(): Unit = { + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("attrs", MapType(StringType, StringType)))) + val res = LanceTempR.checkSupported(notOk) + assertTrue(res.isDefined, "MapType must be rejected") + assertTrue(res.get.contains("attrs"), s"reason should name the column: ${res.get}") + assertTrue(res.get.toLowerCase.contains("map"), s"reason should mention Map: ${res.get}") + } + + /** Map nested inside an Array also rejected (recursive check). */ + @Test def testCheckSupportedRejectsArrayOfMap(): Unit = { + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("nested", ArrayType(MapType(StringType, IntegerType))))) + assertTrue(LanceTempR.checkSupported(notOk).isDefined) + } + + /** Map nested inside a Struct also rejected. */ + @Test def testCheckSupportedRejectsStructWithMap(): Unit = { + val inner = new StructType(Array(StructField("ext", MapType(StringType, StringType)))) + val notOk = new StructType(Array( + StructField("rid", LongType), + StructField("metadata", inner))) + assertTrue(LanceTempR.checkSupported(notOk).isDefined) + } + + /** + * materialize() throws IllegalArgumentException when projection includes an unsupported + * type — caller (DataFrame API path) propagates this to the user. + */ + @Test def testMaterializeRejectsUnsupportedProjection(): Unit = { + import org.apache.spark.sql.functions.{lit, map} + // Construct a DataFrame with a map column. + val parquetUri = writeRandomParquet(2, Dim) + val withMap = spark.read.parquet(parquetUri) + .withColumn("attrs", map(lit("k1"), lit("v1"))) + + val ex = assertThrows( + classOf[IllegalArgumentException], + () => + LanceTempR.materialize( + withMap, + vecCol = "vec", + projection = Seq("id", "attrs"), + scratchDir = scratch())) + assertTrue( + ex.getMessage.toLowerCase.contains("map") || ex.getMessage.contains("attrs"), + s"error should name the rejected column / type; got: ${ex.getMessage}") + } + + // -- resolveScratchDir --------------------------------------------------------------------- + + /** Conf key set: returned as-is. */ + @Test def testResolveScratchDirHonoursConfKey(): Unit = { + val explicit = scratch() + spark.conf.set(LanceTempR.ScratchDirConfKey, explicit) + try { + assertEquals(explicit, LanceTempR.resolveScratchDir(spark)) + } finally spark.conf.unset(LanceTempR.ScratchDirConfKey) + } + + /** + * Conf key unset in local mode: falls back to a local-FS path under spark.local.dir + * (or system tmp). The exact path is implementation-detail; the contract is "non-empty + * and works locally", verified by an actual round-trip below. + */ + @Test def testResolveScratchDirLocalFallbackWorks(): Unit = { + spark.conf.unset(LanceTempR.ScratchDirConfKey) + val resolved = LanceTempR.resolveScratchDir(spark) + assertTrue(resolved.nonEmpty, "local fallback must produce a non-empty path") + // Confirm it actually works as a scratchDir argument: round-trip a tiny dataset. + val parquetUri = writeRandomParquet(4, Dim) + val rightDf = spark.read.parquet(parquetUri) + val tempUri = LanceTempR.materialize( + rightDf, + vecCol = "vec", + projection = Seq.empty, + scratchDir = resolved) + assertEquals(4L, spark.read.format("lance").load(tempUri).count()) + } + + // -- helpers -------------------------------------------------------------------------------- + + /** Build (id, vec) parquet under tempDir; returns its URI. */ + private def writeRandomParquet(n: Int, dim: Int): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()))) + val rng = new Random(Seed) + val rows: Seq[Row] = (0 until n).map { i => + RowFactory.create(Integer.valueOf(i), randomVector(rng, dim)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val uri = subPath("right_parquet").toString + df.write.parquet(uri) + uri + } + + /** Wider source with label + payload + untouched, for projection-trim and subplan tests. */ + private def writeWideParquet(n: Int, dim: Int): String = { + val schema = new StructType(Array( + StructField("id", IntegerType, nullable = false), + StructField( + "vec", + ArrayType(FloatType, containsNull = false), + nullable = false, + new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build()), + StructField("label", StringType, nullable = false), + StructField("payload", StringType, nullable = false), + StructField("untouched", IntegerType, nullable = false))) + val rng = new Random(Seed) + val rows: Seq[Row] = (0 until n).map { i => + RowFactory.create( + Integer.valueOf(i), + randomVector(rng, dim), + if (i % 2 == 0) "even" else "odd", + s"p$i", + Integer.valueOf(i * 17)) + } + val df = spark.createDataFrame(rows.asJava, schema) + val uri = subPath("wide_parquet").toString + df.write.parquet(uri) + uri + } + + /** Scratch directory exists (the helper writes a child of it) — created if needed. */ + private def scratch(): String = { + val p = tempDir.resolve("scratch_" + System.nanoTime()) + Files.createDirectories(p) + p.toString + } + + /** + * Path that does NOT pre-exist — used as a Spark write target. Spark refuses to write + * to an existing path without overwrite mode. + */ + private def subPath(name: String): Path = + tempDir.resolve(name + "_" + System.nanoTime()) + + private def randomVector(rng: Random, dim: Int): Array[Float] = { + val v = new Array[Float](dim) + var i = 0 + while (i < dim) { v(i) = rng.nextFloat(); i += 1 } + v + } +}