Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package org.lance.spark.knn.catalyst

import org.apache.spark.sql.{LanceKnnDatasetBridge, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, VectorCosineSimilarity, VectorInnerProduct, VectorL2Distance}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, NearestByDirection, NearestByDistance, NearestBySimilarity}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, NearestByJoin, Project, SubqueryAlias}
Expand All @@ -21,7 +22,7 @@ import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric}
import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, LanceTempR, Metric}
import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan, ProbedLeftCodec}

/**
Expand Down Expand Up @@ -101,6 +102,22 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
/**
 * Configuration key that gates the rule as a whole. Off by default to keep the rule opt-in
 * for now. [[TempRForSqlEnabledConfKey]] only has an effect when this key is also enabled.
 */
val EnabledConfKey: String = "spark.lance.knn.indexedNearestByJoin.enabled"

/**
 * Configuration key that gates the per-query temp-Lance materialization for non-Lance
 * right sides (sezruby/lance-spark#2). When enabled (and [[EnabledConfKey]] is also on),
 * a `NearestByJoin` whose right side isn't a Lance scan triggers
 * [[org.lance.spark.knn.internal.LanceTempR.materialize]] at rule-application time, then
 * proceeds with the same staged-plan rewrite against the temp URI.
 *
 * Off by default. Two reasons to keep it opt-in:
 *   1. The rule evaluates the right plan synchronously at rule-application time — for
 *      large R that's a meaningful cost users should consciously accept.
 *   2. The scratch directory is resolved via [[LanceTempR.resolveScratchDir]] (cluster
 *      mode is expected to need `spark.lance.knn.tempR.dir`). Any failure on this path
 *      is caught inside the rule, which then falls back to Spark's brute-force rewrite
 *      without surfacing an error — so a misconfiguration shows up only as a slow query.
 *      That is surprising enough to keep behind an explicit opt-in.
 */
val TempRForSqlEnabledConfKey: String = "spark.lance.knn.tempRForSqlRule.enabled"

/**
* IVF cluster count to visit per query. Higher = better recall, more compute. Default
* (None) leaves Lance's index-default (typically 1).
Expand All @@ -120,6 +137,7 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
}
val nprobes = optInt(NprobesConfKey)
val refineFactor = optInt(RefineFactorConfKey)
val tempRForSqlEnabled = conf.getConfString(TempRForSqlEnabledConfKey, "false").toBoolean
plan.transformDown {
case j @ NearestByJoin(left, right, joinType, true, k, rankingExpr, direction) =>
rewriteIfApplicable(
Expand All @@ -131,7 +149,8 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
rankingExpr,
direction,
nprobes,
refineFactor).getOrElse(j)
refineFactor,
tempRForSqlEnabled).getOrElse(j)
}
}

Expand Down Expand Up @@ -164,10 +183,20 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
rankingExpr: Expression,
direction: NearestByDirection,
nprobes: Option[Int],
refineFactor: Option[Int]): Option[LogicalPlan] = {
refineFactor: Option[Int],
tempRForSqlEnabled: Boolean): Option[LogicalPlan] = {
for {
lance <- unwrapLanceScan(right)
// Recognize the ranking BEFORE attempting the right-side resolution so we don't
// pay a temp-Lance materialization for queries that are going to fall through
// anyway (wrong direction, mixed-side rank expression, etc.).
(metric, leftVecAttr, rightVecCol) <- recognizeRanking(rankingExpr, direction, left, right)
lance <- unwrapLanceScan(right).orElse {
if (tempRForSqlEnabled) {
materializeNonLanceR(right, rightVecCol)
} else {
None
}
}
} yield {
val leftVecIdx = left.output.indexWhere(_.exprId == leftVecAttr.exprId)
require(leftVecIdx >= 0, s"left vector attr not found in left.output: $leftVecAttr")
Expand Down Expand Up @@ -422,6 +451,58 @@ object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
}
}

/**
 * Synthesise a [[LanceScanInfo]] for a non-Lance right plan by materializing it to a
 * temp Lance dataset (sezruby/lance-spark#2). The materialization runs synchronously at
 * rule-application time.
 *
 * The synthesised `output` reuses the ORIGINAL right plan's attribute references
 * (`right.output`) — preserving `ExprId`s — so the top-level `Project(j.output, ...)`
 * the rule emits stays attribute-resolved. The materialize stage reads from the temp
 * Lance dataset (which has those columns by name plus the synthetic `_rid`) and binds
 * the read-back rows back to those original attribute references.
 *
 * Failure handling is best-effort: any NON-FATAL exception (missing scratch-dir config,
 * unknown column, failed write, ...) is swallowed and reported as `None`, so the caller
 * falls through to Spark's brute-force rewrite. Fatal conditions (`InterruptedException`,
 * `VirtualMachineError`, `LinkageError`, ...) are deliberately NOT swallowed, so that task
 * cancellation and JVM-level failures still propagate.
 *
 * @param right       the non-Lance right side of the `NearestByJoin`
 * @param rightVecCol name of the right-side vector column used by the ranking expression
 * @return `Some(info)` describing the temp dataset, or `None` when the projected schema
 *         is rejected by `LanceTempR.checkSupported` or a non-fatal error occurred
 */
private def materializeNonLanceR(
    right: LogicalPlan,
    rightVecCol: String): Option[LanceScanInfo] = {
  try {
    val spark = SparkSession.active
    val rightDf = LanceKnnDatasetBridge.asDataFrame(spark, right)
    // Vector column first, then every other right-side column in its original order.
    val projection: Seq[String] = right.output.map(_.name).filterNot(_ == rightVecCol)
    val projectedSchema = StructType(
      (rightVecCol +: projection).map(name => rightDf.schema(name)))
    // Pre-check the schema BEFORE triggering any work. A defined result from
    // `checkSupported` is treated as "cannot be written": refuse the rewrite outright
    // rather than attempt a partial one — same "refusal not partial pushdown" pattern
    // the prefilter translator uses.
    if (LanceTempR.checkSupported(projectedSchema).isDefined) {
      None
    } else {
      val scratchDir = LanceTempR.resolveScratchDir(spark)
      val tempUri = LanceTempR.materialize(
        right = rightDf,
        vecCol = rightVecCol,
        projection = projection,
        scratchDir = scratchDir)
      // URI is the temp dataset; version is None (per-query temp); output keeps the
      // original right.output (preserving ExprIds for the top-level Project); no
      // prefilter, because the temp already represents the FULLY evaluated right plan,
      // so any source-side filters were applied during the write.
      Some(
        LanceScanInfo(
          uri = tempUri,
          version = None,
          output = right.output,
          prefilter = None))
    }
  } catch {
    // Best-effort fallback: the user's query still runs via brute force, just slowly.
    // NOTE(review): the failure is dropped without a trace; consider logging it so a
    // misconfigured scratch dir is diagnosable.
    case scala.util.control.NonFatal(_) =>
      None
  }
}

private def isLanceTable(table: Table): Boolean = {
val cls = table.getClass.getName
// Loose by design — the rule is opt-in via spark.lance.knn.indexedNearestByJoin.enabled, so
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,144 @@ class IndexedNearestByJoinRuleTest {
assertSame(join, rewritten, "computed expression must refuse pushdown")
}

// -- per-query temp-Lance path for non-Lance R (sezruby/lance-spark#2) --------------------

/**
 * Happy path for the temp-R rewrite: with both
 * `spark.lance.knn.indexedNearestByJoin.enabled` and
 * `spark.lance.knn.tempRForSqlRule.enabled` set to true, a parquet-backed (non-Lance)
 * right side must be rewritten into the same `Project(LanceMaterialize(...))` shape the
 * rule produces for a Lance right side, and must carry no prefilter.
 */
@Test def testTempRForSqlRewritesNonLanceR(): Unit = {
  val rule = IndexedNearestByJoinRule
  spark.conf.set(rule.EnabledConfKey, "true")
  spark.conf.set(rule.TempRForSqlEnabledConfKey, "true")
  try {
    val leftPlan = trivialPlan("lid", "lvec")
    // An analyzed parquet scan — deliberately NOT a Lance relation.
    val rightPlan = parquetLikePlan(idCol = "rid", vecCol = "rvec")

    val ranking = VectorL2Distance(
      leftPlan.output.find(_.name == "lvec").get,
      rightPlan.output.find(_.name == "rvec").get)
    val knnJoin = NearestByJoin(
      leftPlan,
      rightPlan,
      Inner,
      approx = true,
      numResults = 5,
      rankingExpression = ranking,
      direction = NearestByDistance)

    val result = rule(knnJoin)

    // Same shape as the Lance-R happy path: a Project directly wrapping the materialize node.
    val hasExpectedShape = result match {
      case p: Project => p.child.isInstanceOf[LanceMaterializeLogicalPlan]
      case _ => false
    }
    assertTrue(
      hasExpectedShape,
      s"expected rewrite to Project(LanceMaterialize(...)), got: $result")

    // NOTE(review): the probe URI is expected to point at a temp dir, but nothing here
    // asserts that — only the prefilter is checked on the summary.
    val summary = expectRewritten(result)
    // The temp dataset already represents the fully evaluated right plan, so there is
    // nothing left to push down as a prefilter.
    assertEquals(None, summary.prefilter)
  } finally {
    spark.conf.unset(rule.TempRForSqlEnabledConfKey)
    spark.conf.unset(rule.EnabledConfKey)
  }
}

/**
 * The temp-R flag alone is not enough: with the main rule flag unset and only
 * `spark.lance.knn.tempRForSqlRule.enabled` turned on, a non-Lance right side must be left
 * untouched. The materialize-and-rewrite path requires BOTH flags.
 */
@Test def testTempRForSqlRequiresMainEnabledFlag(): Unit = {
  val rule = IndexedNearestByJoinRule
  spark.conf.unset(rule.EnabledConfKey) // main flag off
  spark.conf.set(rule.TempRForSqlEnabledConfKey, "true")
  try {
    val leftPlan = trivialPlan("lid", "lvec")
    val rightPlan = parquetLikePlan(idCol = "rid", vecCol = "rvec")
    val ranking = VectorL2Distance(
      leftPlan.output.find(_.name == "lvec").get,
      rightPlan.output.find(_.name == "rvec").get)
    val knnJoin = NearestByJoin(
      leftPlan,
      rightPlan,
      Inner,
      approx = true,
      numResults = 5,
      rankingExpression = ranking,
      direction = NearestByDistance)

    // The very same plan instance must come back when the rule does not fire.
    val result = rule(knnJoin)
    assertSame(
      knnJoin,
      result,
      "main enabled flag off → rule must not fire even with tempR enabled")
  } finally {
    spark.conf.unset(rule.TempRForSqlEnabledConfKey)
  }
}

/**
 * A right side carrying a MapType column (expected to be unsupported by the temp-Lance
 * write) must not make the rule throw, even with both flags on: Catalyst rules shouldn't
 * fail the user's query when a fallback exists. Pinned behavior: the rewrite is refused
 * and the original `NearestByJoin` instance is returned unchanged.
 */
@Test def testTempRForSqlFallsThroughOnUnsupportedSchema(): Unit = {
  val rule = IndexedNearestByJoinRule
  spark.conf.set(rule.EnabledConfKey, "true")
  spark.conf.set(rule.TempRForSqlEnabledConfKey, "true")
  try {
    val leftPlan = trivialPlan("lid", "lvec")
    val rightPlan = parquetLikePlanWithMap(idCol = "rid", vecCol = "rvec")
    val ranking = VectorL2Distance(
      leftPlan.output.find(_.name == "lvec").get,
      rightPlan.output.find(_.name == "rvec").get)
    val knnJoin = NearestByJoin(
      leftPlan,
      rightPlan,
      Inner,
      approx = true,
      numResults = 5,
      rankingExpression = ranking,
      direction = NearestByDistance)

    // NOTE(review): the rule also swallows materialization failures, so this assertion
    // alone cannot tell "schema check refused" apart from "write failed" — confirm the
    // schema check is exercised elsewhere.
    val result = rule(knnJoin)
    assertSame(
      knnJoin,
      result,
      "right with MapType column must fall through (rule cannot materialize unsupported type)")
  } finally {
    spark.conf.unset(rule.TempRForSqlEnabledConfKey)
    spark.conf.unset(rule.EnabledConfKey)
  }
}

/**
 * Regression pin for the pre-existing behavior: with the main rule enabled but the temp-R
 * flag unset, a non-Lance right side is left for the brute-force path. Guards against
 * accidentally changing that default while the rule is being extended.
 */
@Test def testNonLanceRWithoutTempRConfFallsThrough(): Unit = {
  val rule = IndexedNearestByJoinRule
  spark.conf.set(rule.EnabledConfKey, "true")
  spark.conf.unset(rule.TempRForSqlEnabledConfKey)
  try {
    val leftPlan = trivialPlan("lid", "lvec")
    val rightPlan = parquetLikePlan(idCol = "rid", vecCol = "rvec")
    val ranking = VectorL2Distance(
      leftPlan.output.find(_.name == "lvec").get,
      rightPlan.output.find(_.name == "rvec").get)
    val knnJoin = NearestByJoin(
      leftPlan,
      rightPlan,
      Inner,
      approx = true,
      numResults = 5,
      rankingExpression = ranking,
      direction = NearestByDistance)

    // Identity (not just equality) proves the rule returned the input untouched.
    val result = rule(knnJoin)
    assertSame(
      knnJoin,
      result,
      "tempR conf off → non-Lance right must fall through to brute-force")
  } finally {
    spark.conf.unset(rule.EnabledConfKey)
  }
}

/** Filter wrapped in SubqueryAlias still pushes — order of unwrap shouldn't matter. */
@Test def testFilterUnderSubqueryAliasPushes(): Unit = {
spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
Expand Down Expand Up @@ -391,6 +529,47 @@ class IndexedNearestByJoinRuleTest {
spark.createDataFrame(rows.asJava, schema).queryExecution.analyzed
}

/**
 * Builds a parquet-backed analyzed plan for the per-query temp-R rewrite tests. The rule's
 * temp-R path has to actually execute the right plan, so (unlike `trivialPlan`) this must
 * be backed by real data. Every call writes a fresh 4-row parquet directory under
 * `tempDir` (the name is made unique with `System.nanoTime()`) and returns the analyzed
 * plan of reading it back.
 *
 * @param idCol  name of the non-nullable integer id column
 * @param vecCol name of the float-array vector column, tagged with
 *               `arrow.fixed-size-list.size = 8` metadata
 */
private def parquetLikePlan(idCol: String, vecCol: String): LogicalPlan = {
  val vecMeta = new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()
  val fields = Array(
    StructField(idCol, IntegerType, nullable = false),
    StructField(
      vecCol,
      ArrayType(FloatType, containsNull = false),
      nullable = false,
      vecMeta))
  val schema = new StructType(fields)
  // Four rows, each with an 8-element constant vector.
  val rows = (0 until 4).map(n => RowFactory.create(Integer.valueOf(n), Array.fill(8)(0.5f)))
  val target = tempDir.resolve(s"parquet_for_rule_${System.nanoTime()}").toString
  val df = spark.createDataFrame(rows.asJava, schema)
  df.write.parquet(target)
  spark.read.parquet(target).queryExecution.analyzed
}

/**
 * Variant of `parquetLikePlan` whose data additionally carries a map-typed `attrs` column.
 * Used to check that the rule refuses the rewrite and falls through when the right side
 * contains a type the temp-Lance write is expected not to support. Each call writes a
 * fresh parquet directory under `tempDir`.
 *
 * @param idCol  name of the non-nullable integer id column
 * @param vecCol name of the float-array vector column, tagged with
 *               `arrow.fixed-size-list.size = 8` metadata
 */
private def parquetLikePlanWithMap(idCol: String, vecCol: String): LogicalPlan = {
  import org.apache.spark.sql.functions.{lit => sqlLit, map => mapFn}
  val vecMeta = new MetadataBuilder().putLong("arrow.fixed-size-list.size", 8L).build()
  val fields = Array(
    StructField(idCol, IntegerType, nullable = false),
    StructField(
      vecCol,
      ArrayType(FloatType, containsNull = false),
      nullable = false,
      vecMeta))
  val schema = new StructType(fields)
  val rows = (0 until 4).map(n => RowFactory.create(Integer.valueOf(n), Array.fill(8)(0.5f)))
  val target = tempDir.resolve(s"parquet_with_map_${System.nanoTime()}").toString
  // Same base frame as parquetLikePlan, plus a constant single-entry map column.
  val base = spark.createDataFrame(rows.asJava, schema)
  val withAttrs = base.withColumn("attrs", mapFn(sqlLit("k"), sqlLit("v")))
  withAttrs.write.parquet(target)
  spark.read.parquet(target).queryExecution.analyzed
}

/**
* Build a `DataSourceV2Relation` whose `table.getClass.getName.contains("Lance")` so the
* rule's duck-type check accepts it. We don't actually run any I/O. Includes a `category`
Expand Down
Loading