diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index b0f09bc43b..b62a000f6c 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -354,6 +354,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.comet.exec.CometPythonMapInArrowSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index c743d1888a..fe972818e6 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -193,6 +193,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.comet.exec.CometPythonMapInArrowSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite org.apache.spark.CometPluginsSuite diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml new file mode 100644 index 0000000000..211a9bd23a --- /dev/null +++ b/.github/workflows/pyarrow_udf_test.yml @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: PyArrow UDF Tests + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + push: + branches: + - main + paths-ignore: + - "benchmarks/**" + - "doc/**" + - "docs/**" + - "**.md" + - "dev/changelog/*.md" + - "native/core/benches/**" + - "native/spark-expr/benches/**" + - "spark/src/test/scala/org/apache/spark/sql/benchmark/**" + - "spark/src/main/scala/org/apache/comet/GenerateDocs.scala" + pull_request: + paths-ignore: + - "benchmarks/**" + - "doc/**" + - "docs/**" + - "**.md" + - "dev/changelog/*.md" + - "native/core/benches/**" + - "native/spark-expr/benches/**" + - "spark/src/test/scala/org/apache/spark/sql/benchmark/**" + - "spark/src/main/scala/org/apache/comet/GenerateDocs.scala" + workflow_dispatch: + +permissions: + contents: read + +env: + RUST_VERSION: stable + RUST_BACKTRACE: 1 + RUSTFLAGS: "-Clink-arg=-fuse-ld=bfd" + +jobs: + pyarrow-udf: + name: PyArrow UDF (Spark 4.0, JDK 17, Python 3.11) + runs-on: ubuntu-latest + container: + # Pinned to the Debian 12 (bookworm) base so the system `python3` is 3.11. The default + # `amd64/rust` image is Debian 13 (trixie) which ships Python 3.13 and no python3.11 apt + # package, breaking `apt-get install python3.11`. 
+ image: rust:bookworm + env: + JAVA_TOOL_OPTIONS: "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED --add-exports=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED" + steps: + - uses: actions/checkout@v6 + + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.RUST_VERSION }} + jdk-version: 17 + + - name: Cache Maven dependencies + uses: actions/cache@v5 + with: + path: | + ~/.m2/repository + /root/.m2/repository + key: ${{ runner.os }}-java-maven-${{ hashFiles('**/pom.xml') }}-pyarrow-udf + restore-keys: | + ${{ runner.os }}-java-maven- + + - name: Build Comet (debug, Spark 4.0 / Scala 2.13) + run: | + cd native && cargo build + cd .. && ./mvnw -B install -DskipTests -Pspark-4.0 -Pscala-2.13 + + - name: Install Python 3.11 and pip + run: | + apt-get update + apt-get install -y --no-install-recommends python3 python3-venv python3-pip + python3 -m venv /tmp/venv + /tmp/venv/bin/pip install --upgrade pip + /tmp/venv/bin/pip install "pyspark==4.0.1" "pyarrow>=14" pandas pytest + + - name: Run PyArrow UDF pytest + env: + # Spark launches Python workers in a fresh subprocess and looks up `python3` + # on PATH unless PYSPARK_PYTHON is set. Without this, workers use the system + # python which has no pyarrow installed and UDF execution fails with + # ModuleNotFoundError. + PYSPARK_PYTHON: /tmp/venv/bin/python + PYSPARK_DRIVER_PYTHON: /tmp/venv/bin/python + run: | + /tmp/venv/bin/python -m pytest -v \ + spark/src/test/resources/pyspark/test_pyarrow_udf.py diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index d3f51dfbe2..0bdc35d3ce 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -314,6 +314,18 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_PYARROW_UDF_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.pyarrowUdf.enabled") + .category(CATEGORY_EXEC) + .doc( + "Experimental: whether to enable optimized execution of PyArrow UDFs " + + "(mapInArrow/mapInPandas). When enabled, Comet passes Arrow columnar data " + + "directly to Python UDFs without the intermediate Arrow-to-Row-to-Arrow " + + "conversion that Spark normally performs. Disabled by default while the " + + "feature stabilizes.") + .booleanConf + .createWithDefault(false) + val COMET_TRACING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.tracing.enabled") .category(CATEGORY_TUNING) .doc(s"Enable fine-grained tracing of events and memory usage. $TRACING_GUIDE.") diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 480ec4f702..c96dea7750 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -38,5 +38,6 @@ Comet $COMET_VERSION User Guide Understanding Comet Plans Tuning Guide Metrics Guide + PyArrow UDF Acceleration Iceberg Guide Kubernetes Guide diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md new file mode 100644 index 0000000000..23ef50e79c --- /dev/null +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -0,0 +1,188 @@ + + +# PyArrow UDF Acceleration + +Comet can accelerate Python UDFs that use PyArrow-backed batch processing, such as `mapInArrow` and `mapInPandas`. 
+These APIs are commonly used for ML inference, feature engineering, and data transformation workloads.
+
+## Background
+
+Spark's `mapInArrow` and `mapInPandas` APIs allow users to apply Python functions that operate on Arrow
+RecordBatches or Pandas DataFrames. Under the hood, Spark communicates with the Python worker process
+using the Arrow IPC format.
+
+Without this optimization, the execution path for these UDFs involves unnecessary data conversions:
+
+1. Comet reads data in Arrow columnar format (via CometScan)
+2. Spark inserts a ColumnarToRow transition (converts Arrow to UnsafeRow)
+3. The Python runner converts those rows back to Arrow to send to Python
+4. Python executes the UDF on Arrow batches
+5. Results are returned as Arrow and then converted back to rows
+
+Steps 2 and 3 are redundant since the data starts and ends in Arrow format.
+
+## How Comet Optimizes This
+
+When enabled, Comet detects `PythonMapInArrowExec` / `MapInArrowExec` and `MapInPandasExec`
+operators in the physical plan and replaces them with `CometMapInBatchExec`, which:
+
+- Reads Arrow columnar batches directly from the upstream Comet operator
+- Feeds them to the Python runner without the expensive UnsafeProjection copy
+- Keeps the Python output in columnar format for downstream operators
+
+This eliminates the ColumnarToRow transition and the output row conversion, reducing CPU overhead
+and memory allocations. The internal row-to-Arrow IPC re-encoding inside Spark's
+`ArrowPythonRunner` is unchanged in this version; full round-trip elimination is tracked in
+[#4240](https://github.com/apache/datafusion-comet/issues/4240).
+
+### Plan flow
+
+Without Comet's optimization:
+
+```
+PythonMapInArrow / MapInArrow / MapInPandas
++- ColumnarToRow              <- Arrow -> Row copy
+   +- CometNativeExec         <- Arrow batch
+      +- CometScan
+```
+
+With the optimization enabled:
+
+```
+CometMapInBatch               <- Arrow batch in/out, Python runner attached
++- CometNativeExec
+   +- CometScan
+```
+
+## Configuration
+
+The optimization is experimental and disabled by default while the feature stabilizes. Enable it with:
+
+```
+spark.comet.exec.pyarrowUdf.enabled=true
+```
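+
+The flag is read when Comet plans each query, so it can also be toggled at runtime on an existing
+session rather than only at session startup (the pytest suite added in this PR flips it per test
+with `spark.conf.set`). A minimal sketch, assuming an already-running Comet-enabled session named
+`spark`:
+
+```python
+# Toggle the optimization for subsequent queries on an existing session.
+spark.conf.set("spark.comet.exec.pyarrowUdf.enabled", "true")
+
+# Confirm the current value before running a mapInArrow/mapInPandas query.
+print(spark.conf.get("spark.comet.exec.pyarrowUdf.enabled"))
+```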
+
+## Supported APIs
+
+| PySpark API                      | Spark Plan Node             | Supported |
+| -------------------------------- | --------------------------- | --------- |
+| `df.mapInArrow(func, schema)`    | `PythonMapInArrowExec`      | Yes       |
+| `df.mapInPandas(func, schema)`   | `MapInPandasExec`           | Yes       |
+| `@pandas_udf` (scalar)           | `ArrowEvalPythonExec`       | Not yet   |
+| `df.applyInPandas(func, schema)` | `FlatMapGroupsInPandasExec` | Not yet   |
+
+## Example
+
+```python
+from typing import Iterator
+
+import pyarrow as pa
+from pyspark.sql import SparkSession, types as T
+
+spark = SparkSession.builder \
+    .config("spark.plugins", "org.apache.spark.CometPlugin") \
+    .config("spark.comet.enabled", "true") \
+    .config("spark.comet.exec.enabled", "true") \
+    .config("spark.comet.exec.pyarrowUdf.enabled", "true") \
+    .config("spark.memory.offHeap.enabled", "true") \
+    .config("spark.memory.offHeap.size", "2g") \
+    .getOrCreate()
+
+df = spark.read.parquet("data.parquet")
+
+# mapInArrow expects a function that takes an iterator of RecordBatches and
+# yields RecordBatches whose schema matches the declared output schema.
+def transform(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+    for batch in batches:
+        # Your transformation logic here
+        pdf = batch.to_pandas()
+        pdf["new_col"] = pdf["value"] * 2
+        yield pa.RecordBatch.from_pandas(pdf[["value", "new_col"]])
+
+output_schema = T.StructType([
+    T.StructField("value", T.DoubleType()),
+    T.StructField("new_col", T.DoubleType()),
+])
+
+result = df.mapInArrow(transform, output_schema)
+```
+
+## Verifying the Optimization
+
+Use `explain()` to verify that `CometMapInBatch` appears in your plan:
+
+```python
+result.explain(mode="extended")
+```
+
+You should see:
+
+```
+CometMapInBatch ...
++- CometNativeExec ...
+   +- CometScan ...
+```
+
+Instead of the unoptimized plan:
+
+```
+PythonMapInArrow ...
++- ColumnarToRow
+   +- CometNativeExec ...
+      +- CometScan ...
+```
+
+When AQE is enabled (the Spark default) and the query contains a shuffle, the
+optimization is applied during stage materialization. Calling `explain()` before
+running an action will show the unoptimized plan:
+
+```
+AdaptiveSparkPlan isFinalPlan=false
++- PythonMapInArrow ...
+   +- CometExchange ...
+```
+
+To see the optimized plan, run an action first (for example `result.collect()` or
+`result.cache(); result.count()`) and then call `explain()`. The post-execution
+plan shows the materialized stages and includes `CometMapInBatch` if the
+optimization fired.
+
+## Barrier execution
+
+`mapInArrow(..., barrier=True)` and `mapInPandas(..., barrier=True)` are honored: the
+optimized operator propagates `isBarrier` through `RDD.barrier()`, so all tasks are
+gang-scheduled and `BarrierTaskContext.barrier()` works inside the UDF the same way it does
+on the unoptimized path.
+
+## Limitations
+
+- The optimization currently applies only to `mapInArrow` and `mapInPandas`. Scalar pandas UDFs
+  (`@pandas_udf`) and grouped operations (`applyInPandas`) are not yet supported.
+- The internal row-to-Arrow conversion inside the Python runner is still present in this version.
+  Comet currently routes columnar input through `ColumnarBatch.rowIterator()` so that the existing
+  `ArrowPythonRunner` can re-encode the rows back to Arrow IPC. A future optimization will write
+  Arrow batches directly to the Python IPC stream, eliminating the remaining round-trip and
+  achieving near zero-copy data transfer.
+- The optimization requires Arrow data on the input side. If a shuffle sits between the upstream
+  Comet operator and the Python UDF, you need Comet's native shuffle for the optimization to
+  apply.
Set `spark.shuffle.manager` to + `org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager` and enable + `spark.comet.exec.shuffle.enabled=true` at session startup. With a vanilla Spark `Exchange` + in the plan the data leaves the shuffle as rows and the optimization cannot fire. +- Spark 3.4 lacks several APIs the optimization depends on (`MapInBatchExec.isBarrier`, + `arrowUseLargeVarTypes`, `JobArtifactSet`, the modern `ArrowPythonRunner` constructor). On + Spark 3.4 the feature is a no-op even when enabled. Spark 3.5+ is required. diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7402a83248..24c969c173 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -22,8 +22,9 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometMapInBatchExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.shims.ShimCometMapInBatch import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec @@ -51,7 +52,9 @@ import org.apache.comet.CometConf // various reasons) or Spark requests row-based output such as a `collect` call. Spark will adds // another `ColumnarToRowExec` on top of `CometSparkToColumnarExec`. In this case, the pair could // be removed. -case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] { +case class EliminateRedundantTransitions(session: SparkSession) + extends Rule[SparkPlan] + with ShimCometMapInBatch { private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() @@ -98,6 +101,27 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa case CometNativeColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child + // Replace MapInBatchExec (PythonMapInArrowExec / MapInArrowExec / MapInPandasExec) that has + // a ColumnarToRow child with CometMapInBatchExec, eliminating the input and output + // UnsafeProjection copies and keeping the stage columnar. The matchers are + // version-shimmed: Spark 3.4 returns None (it lacks the required APIs) and Spark 4.1+ + // matches the renamed `MapInArrowExec`. + case p: SparkPlan if CometConf.COMET_PYARROW_UDF_ENABLED.get() => + matchMapInArrow(p).orElse(matchMapInPandas(p)) match { + case Some(info) => + extractColumnarChild(info.child) + .map { columnarChild => + CometMapInBatchExec( + info.func, + info.output, + columnarChild, + info.isBarrier, + info.pythonEvalType) + } + .getOrElse(p) + case None => p + } + // Spark adds `RowToColumnar` under Comet columnar shuffle. 
But it's redundant as the // shuffle takes row-based input. case s @ CometShuffleExchangeExec( @@ -130,6 +154,19 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa } } + /** + * If the given plan is a Comet ColumnarToRow transition, returns the columnar child the Python + * UDF operator can consume directly. By the time this rule runs the earlier + * `hasCometNativeChild` arm has already rewritten any `ColumnarToRowExec` over a Comet columnar + * source to one of the Comet variants, so vanilla `ColumnarToRowExec` cannot reach here on a + * Comet-driven plan and is intentionally not handled. + */ + private def extractColumnarChild(plan: SparkPlan): Option[SparkPlan] = plan match { + case CometColumnarToRowExec(child) => Some(child) + case CometNativeColumnarToRowExec(child) => Some(child) + case _ => None + } + /** * Creates an appropriate columnar to row transition operator. * diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala new file mode 100644 index 0000000000..77dbfff7ce --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.JavaConverters._ + +import org.apache.spark.{ContextAwareIterator, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.shims.ShimCometMapInBatch +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.{BatchIterator, PythonSQLMetrics} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +/** + * Comet replacement for Spark's `MapInBatchExec` family (`PythonMapInArrowExec` / + * `MapInArrowExec` in 4.1+ / `MapInPandasExec`). Accepts columnar input directly from a Comet + * child instead of going through the per-row `UnsafeProjection` that `ColumnarToRowExec` applies, + * and keeps the Python runner output as `ColumnarBatch` so downstream Comet operators consume it + * natively. + * + * What this eliminates: two `UnsafeProjection` copies (input and output) and the row transition + * between Comet and the Python operator. 
The internal row-to-Arrow IPC re-encoding inside + * `ArrowPythonRunner` is unchanged; full round-trip elimination is tracked in #4240. + */ +case class CometMapInBatchExec( + func: Expression, + output: Seq[Attribute], + child: SparkPlan, + isBarrier: Boolean, + pythonEvalType: Int) + extends UnaryExecNode + with CometPlan + with PythonSQLMetrics + with ShimCometMapInBatch { + + override def supportsColumnar: Boolean = true + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows")) ++ + pythonMetrics + + // Fallback for row-consuming parents (e.g. a top-level `collect()` that produces rows). + // Wraps this columnar exec in `ColumnarToRowExec`, reintroducing exactly the row transition + // this operator otherwise eliminates. Only fires when nothing downstream consumes columnar. + override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputBatches = longMetric("numOutputBatches") + val numInputRows = longMetric("numInputRows") + + val pythonUDF = func.asInstanceOf[PythonUDF] + val outputAttrs = output + val childSchema = child.schema + val batchSize = conf.arrowMaxRecordsPerBatch + val evalType = pythonEvalType + val sqlConf = conf + val metricsCopy = pythonMetrics + + val inputRDD = child.executeColumnar() + + def processPartition(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { + val context = TaskContext.get() + val argOffsets = Array(Array(0)) + + val rowIter = batches.flatMap { batch => + numInputRows += batch.numRows() + batch.rowIterator().asScala + } + + val contextAwareIterator = new ContextAwareIterator(context, rowIter) + + // Wrap rows as a struct, matching MapInBatchEvaluatorFactory behavior + val wrappedIter = contextAwareIterator.map(InternalRow(_)) + + val batchIter = + if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) + + val columnarBatchIter = computeArrowPython( + pythonUDF, + evalType, + argOffsets, + StructType(Array(StructField("struct", childSchema))), + sqlConf, + metricsCopy, + batchIter, + context.partitionId(), + context) + + columnarBatchIter.map { batch => + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = outputAttrs.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + numOutputRows += flattenedBatch.numRows() + numOutputBatches += 1 + flattenedBatch + } + } + + // Preserve isBarrier semantics: when set, run inside a barrier stage so all tasks + // are gang-scheduled and BarrierTaskContext.barrier() works inside the UDF. 
+ if (isBarrier) { + inputRDD.barrier().mapPartitions(processPartition) + } else { + inputRDD.mapPartitionsInternal(processPartition) + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): CometMapInBatchExec = + copy(child = newChild) +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala new file mode 100644 index 0000000000..f610c575b1 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.execution.SparkPlan + +/** + * Spark-version-agnostic projection of a `MapInBatchExec` (`PythonMapInArrowExec`, + * `MapInArrowExec`, or `MapInPandasExec`) that the Comet rewrite needs. Lives outside the shims + * so the Comet planner can pattern-match on it without depending on which concrete Spark class + * was matched. + */ +case class MapInBatchInfo( + func: Expression, + output: Seq[Attribute], + child: SparkPlan, + isBarrier: Boolean, + pythonEvalType: Int) diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..c7d6ae2f97 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Spark 3.4 shim for the PyArrow UDF acceleration support. + * + * Spark 3.4 lacks several APIs that the optimization relies on (`isBarrier` on `MapInBatchExec`, + * `arrowUseLargeVarTypes`, `JobArtifactSet`, the modern `ArrowPythonRunner` constructor), so the + * matchers return `None` and the runner factory throws. The optimization is effectively a no-op + * on Spark 3.4. + */ +trait ShimCometMapInBatch { + + protected def matchMapInArrow(plan: SparkPlan): Option[MapInBatchInfo] = None + + protected def matchMapInPandas(plan: SparkPlan): Option[MapInBatchInfo] = None + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + conf: SQLConf, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + throw new UnsupportedOperationException("CometMapInBatchExec is not supported on Spark 3.4") +} diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..42d66465f4 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInPandasExec, PythonMapInArrowExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch { + + protected def matchMapInArrow(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: PythonMapInArrowExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF)) + case _ => None + } + + protected def matchMapInPandas(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: MapInPandasExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)) + case _ => None + } + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + conf: SQLConf, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonUDF.func))) + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + new ArrowPythonRunner( + chainedFunc, + evalType, + argOffsets, + schema, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + ArrowPythonRunner.getPythonRunnerConfMap(conf), + pythonMetrics, + jobArtifactUUID).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..0c21cb3738 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.ArrowPythonRunner +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + conf: SQLConf, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val r = runnerInputs(pythonUDF, conf) + new ArrowPythonRunner( + r.chainedFunc, + evalType, + argOffsets, + schema, + r.timeZoneId, + r.largeVarTypes, + r.pythonRunnerConf, + pythonMetrics, + r.jobArtifactUUID, + None).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..e73748aafe --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.ArrowPythonRunner +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + conf: SQLConf, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val r = runnerInputs(pythonUDF, conf) + new ArrowPythonRunner( + r.chainedFunc, + evalType, + argOffsets, + schema, + r.timeZoneId, + r.largeVarTypes, + r.pythonRunnerConf, + pythonMetrics, + r.jobArtifactUUID, + None, + None).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..0c21cb3738 --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.ArrowPythonRunner +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + pythonUDF: PythonUDF, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + conf: SQLConf, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[InternalRow]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = { + val r = runnerInputs(pythonUDF, conf) + new ArrowPythonRunner( + r.chainedFunc, + evalType, + argOffsets, + schema, + r.timeZoneId, + r.largeVarTypes, + r.pythonRunnerConf, + pythonMetrics, + r.jobArtifactUUID, + None).compute(batchIter, partitionId, context) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala b/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala new file mode 100644 index 0000000000..78672aea5e --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInArrowExec, MapInPandasExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Shared 4.x bits for `ShimCometMapInBatch`. The matchers and `getRunnerInputs` helper are + * identical across 4.0/4.1/4.2; only the `ArrowPythonRunner` constructor parameter list differs + * per minor, so each minor's `ShimCometMapInBatch` provides only `computeArrowPython`. 
+ */ +trait Spark4xMapInBatchSupport { + + protected def matchMapInArrow(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: MapInArrowExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF)) + case _ => None + } + + protected def matchMapInPandas(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: MapInPandasExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)) + case _ => None + } + + /** Inputs every 4.x `ArrowPythonRunner` constructor needs in the same shape. */ + protected case class RunnerInputs( + chainedFunc: Seq[(ChainedPythonFunctions, Long)], + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + jobArtifactUUID: Option[String]) + + protected def runnerInputs(pythonUDF: PythonUDF, conf: SQLConf): RunnerInputs = + RunnerInputs( + chainedFunc = Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)), + timeZoneId = conf.sessionLocalTimeZone, + largeVarTypes = conf.arrowUseLargeVarTypes, + pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf), + jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)) +} diff --git a/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py new file mode 100644 index 0000000000..49574130c0 --- /dev/null +++ b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +End-to-end wall-clock benchmark for Comet's PyArrow UDF acceleration. + +Times `df.mapInArrow(passthrough, schema).count()` and the equivalent +`mapInPandas` query with `spark.comet.exec.pyarrowUdf.enabled` set +to false (vanilla Spark path) and true (Comet's optimized path). Both +modes run the same Python worker, so the measured delta covers what the +optimization actually changes for users: + + * vanilla: CometScan -> ColumnarToRow + UnsafeProjection -> ArrowPythonRunner + * optimized: CometScan -> rowIterator -> ArrowPythonRunner (same runner; + no UnsafeProjection, output kept as ColumnarBatch) + +Results are wall-clock seconds, so they include Python interpreter, +Arrow IPC, and downstream count() costs. That's intentional: the +optimization's user-visible value is what fraction of end-to-end time +it shaves off, not the JVM-side delta in isolation. + +Caveat: the workload here is `passthrough_udf` + `count()` on `local[2]`, +so most of the wall time is Spark's Python fork/IPC overhead with very +little real Python work. 
Real UDFs (PyArrow compute, pandas ops, model +inference) increase the per-row Python cost, which dilutes the JVM-side +savings and shrinks the speedup ratio relative to what you see here. + +Usage: + # Build Comet (release for representative numbers): + make release + + pip install pyspark==3.5.8 pyarrow pandas + + python3 spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py + +Override defaults via environment variables: + COMET_JAR=/path/to/comet.jar path to the Comet jar + BENCHMARK_ROWS=2000000 rows per run + BENCHMARK_WARMUP=2 warmup iterations per case + BENCHMARK_ITERS=5 measured iterations per case +""" + +import contextlib +import os +import statistics +import sys +import tempfile +import time + +from pyspark.sql import SparkSession + +sys.path.insert(0, os.path.dirname(__file__)) +from conftest import resolve_comet_jar + + +def _build_spark() -> SparkSession: + jar = resolve_comet_jar() + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + return ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-benchmark") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "4g") + .config("spark.driver.memory", "4g") + # Pin AQE off so the explain output and plan structure are stable + # across iterations. AQE doesn't change the optimization's behavior; + # it just makes plan inspection harder. + .config("spark.sql.adaptive.enabled", "false") + .getOrCreate() + ) + + +def _passthrough_arrow(iterator): + for batch in iterator: + yield batch + + +def _passthrough_pandas(iterator): + for pdf in iterator: + yield pdf + + +def _narrow_primitives(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + ) + + +def _mixed_with_strings(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + "concat('row_', cast(id as string)) as id_str", + "cast(id % 2 as boolean) as id_bool", + ) + + +def _wide_rows(spark: SparkSession, n: int): + types = ["int", "long", "double"] + cols = [ + f"cast(id + {i} as {types[i % len(types)]}) as col_{i}" for i in range(50) + ] + return spark.range(n).selectExpr(*cols) + + +WORKLOADS = [ + ("narrow primitives", _narrow_primitives), + ("mixed with strings", _mixed_with_strings), + ("wide rows (50 cols)", _wide_rows), +] + + +@contextlib.contextmanager +def _temp_parquet(spark: SparkSession, build_df, n: int): + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "src.parquet") + build_df(spark, n).write.parquet(path) + yield path + + +def _time_run(spark: SparkSession, parquet_path: str, accelerate: bool, api: str) -> float: + spark.conf.set( + "spark.comet.exec.pyarrowUdf.enabled", + "true" if accelerate else "false", + ) + df = spark.read.parquet(parquet_path) + schema = df.schema + if api == "mapInArrow": + df = df.mapInArrow(_passthrough_arrow, schema) + else: + df = df.mapInPandas(_passthrough_pandas, schema) + t0 = time.perf_counter() + df.count() + return time.perf_counter() - t0 + + +def main() -> None: + rows = int(os.environ.get("BENCHMARK_ROWS", 1024 * 1024)) + warmup = int(os.environ.get("BENCHMARK_WARMUP", 2)) + iters = int(os.environ.get("BENCHMARK_ITERS", 5)) + + spark = _build_spark() + 
spark.sparkContext.setLogLevel("WARN") + + print(f"\nrows per run: {rows:,}") + print(f"warmup iters: {warmup}, measured iters: {iters}") + print(f"jar: {resolve_comet_jar()}\n") + + header = " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + "api", "mode", "min (s)", "median (s)", "max (s)", "rows/s", "speedup" + ) + print(header) + print(" " + "-" * (len(header) - 2)) + + for name, build_df in WORKLOADS: + print(f"\n=== {name} ===") + with _temp_parquet(spark, build_df, rows) as parquet_path: + for api in ("mapInArrow", "mapInPandas"): + samples_by_mode = {} + for mode, accelerate in (("vanilla", False), ("optimized", True)): + for _ in range(warmup): + _time_run(spark, parquet_path, accelerate, api) + samples = [ + _time_run(spark, parquet_path, accelerate, api) + for _ in range(iters) + ] + samples_by_mode[mode] = samples + median = statistics.median(samples) + speedup = "" + if mode == "optimized": + speedup = "{:.2f}x".format( + statistics.median(samples_by_mode["vanilla"]) / median + ) + print( + " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + api, + mode, + "{:.3f}".format(min(samples)), + "{:.3f}".format(median), + "{:.3f}".format(max(samples)), + "{:,.0f}".format(rows / median), + speedup, + ) + ) + + spark.stop() + + +if __name__ == "__main__": + main() diff --git a/spark/src/test/resources/pyspark/conftest.py b/spark/src/test/resources/pyspark/conftest.py new file mode 100644 index 0000000000..35d6d85191 --- /dev/null +++ b/spark/src/test/resources/pyspark/conftest.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Shared helpers for the pytest modules under this directory and for the +benchmark scripts that import them. + +`resolve_comet_jar` returns the path to the Comet jar a Spark session needs. +Resolution order: the `COMET_JAR` env var (taken verbatim if it points at a +file, expanded as a glob otherwise), then `/spark/target` matched against +the installed pyspark major.minor version. +""" + +import glob +import os + + +REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") +) + + +def resolve_comet_jar() -> str: + explicit = os.environ.get("COMET_JAR") + if explicit: + if any(ch in explicit for ch in "*?["): + matches = sorted(glob.glob(explicit)) + if not matches: + raise FileNotFoundError( + f"COMET_JAR pattern matched nothing: {explicit}" + ) + return matches[-1] + return explicit + + # Pick the jar that matches the installed pyspark major.minor version. The + # Comet jars are published per Spark version (e.g. + # comet-spark-spark3.5_2.12-*.jar); using the wrong one yields + # ClassNotFoundException on Scala stdlib classes. 
+ import pyspark + + major_minor = ".".join(pyspark.__version__.split(".")[:2]) + spark_tag = f"spark{major_minor}" + scala_tag = "_2.12" if major_minor.startswith("3.") else "_2.13" + pattern = os.path.join( + REPO_ROOT, + f"spark/target/comet-spark-{spark_tag}{scala_tag}-*-SNAPSHOT.jar", + ) + candidates = [ + m + for m in sorted(glob.glob(pattern)) + if "sources" not in os.path.basename(m) and "tests" not in os.path.basename(m) + ] + if not candidates: + raise FileNotFoundError( + "Comet jar not found. Set COMET_JAR or run `make release`. " + f"Looked under {pattern}." + ) + return candidates[-1] diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py new file mode 100644 index 0000000000..87558ec057 --- /dev/null +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pytest-driven integration tests for Comet's PyArrow UDF acceleration. + +Each test runs against two execution paths: + - "accelerated": spark.comet.exec.pyarrowUdf.enabled=true + (plan should contain CometMapInBatch and no ColumnarToRow) + - "fallback": spark.comet.exec.pyarrowUdf.enabled=false + (plan should contain vanilla PythonMapInArrow / MapInArrow) + +Usage: + # Build Comet first: + make + + # Then either let the test discover the jar from spark/target, or pass it + # explicitly via COMET_JAR: + export COMET_JAR=$PWD/spark/target/comet-spark-spark3.5_2.12-0.16.0-SNAPSHOT.jar + + pip install pyspark==3.5.8 pyarrow pandas pytest + pytest -v spark/src/test/resources/pyspark/test_pyarrow_udf.py +""" + +import datetime as dt +import os +from decimal import Decimal + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession, types as T + +from conftest import resolve_comet_jar + + +@pytest.fixture(scope="session") +def spark(): + jar = resolve_comet_jar() + # PYSPARK_SUBMIT_ARGS is consumed when pyspark launches its JVM. Setting + # --jars puts the Comet jar on both driver and executor classpaths so the + # CometPlugin can be loaded. 
+ os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + session = ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-tests") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "2g") + .getOrCreate() + ) + try: + yield session + finally: + session.stop() + + +@pytest.fixture(params=[True, False], ids=["accelerated", "fallback"]) +def accelerated(request, spark) -> bool: + spark.conf.set( + "spark.comet.exec.pyarrowUdf.enabled", + "true" if request.param else "false", + ) + return request.param + + +def _executed_plan(df) -> str: + return df._jdf.queryExecution().executedPlan().toString() + + +def _assert_plan_matches_mode( + plan: str, accelerated: bool, vanilla_node: str = "MapInArrow" +) -> None: + if accelerated: + assert "CometMapInBatch" in plan, ( + f"expected CometMapInBatch in accelerated plan, got:\n{plan}" + ) + assert "ColumnarToRow" not in plan, ( + f"unexpected ColumnarToRow in accelerated plan:\n{plan}" + ) + else: + assert "CometMapInBatch" not in plan, ( + f"unexpected CometMapInBatch in fallback plan:\n{plan}" + ) + assert vanilla_node in plan, ( + f"expected {vanilla_node} in fallback plan, got:\n{plan}" + ) + + +def test_map_in_arrow_doubles_value(spark, tmp_path, accelerated): + data = [(i, float(i * 1.5), f"name_{i}") for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value", "name"]).write.parquet(src) + + def double_value(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] * 2 + yield pa.RecordBatch.from_pandas(pdf) + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("name", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(double_value, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + assert row["name"] == original[2] + + +# All other tests use the default `vanilla_node="MapInArrow"`. The mapInPandas tests below +# pass `MapInPandas` explicitly. The substring is the same on Spark 3.5 (PythonMapInArrowExec) +# and Spark 4.x (MapInArrowExec) since the latter is a substring of the former. 
+ + +def test_map_in_arrow_changes_schema(spark, tmp_path, accelerated): + data = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def add_computed_column(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["squared"] = pdf["value"] ** 2 + pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") + yield pa.RecordBatch.from_pandas(pdf) + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + T.StructField("label", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(add_computed_column, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 + assert row["label"] == f"item_{i}" + + +def test_map_in_pandas_doubles_value(spark, tmp_path, accelerated): + data = [(i, float(i * 1.5)) for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def double_value(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["value"] = pdf["value"] * 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(double_value, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + + +def test_map_in_pandas_changes_schema(spark, tmp_path, accelerated): + data = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def add_squared(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["squared"] = pdf["value"] ** 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(add_squared, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 + + +def test_map_in_arrow_preserves_nulls(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("name", T.StringType()), + ] + ) + rows = [ + (1, "a"), + (2, None), + (None, "c"), + (None, None), + (5, "e"), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + # Pure Arrow passthrough so nulls survive without a pandas roundtrip + # (pandas would coerce null longs to NaN floats). 
+ for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["name"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_empty_input(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + src = str(tmp_path / "src.parquet") + spark.createDataFrame([(1, 1.0), (2, 2.0)], schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + # Filter all rows out so the operator sees an empty stream from CometScan. + result_df = ( + spark.read.parquet(src).where("id < 0").mapInArrow(passthrough, schema_in) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + assert result_df.count() == 0 + + +def test_map_in_arrow_python_exception_propagates(spark, tmp_path, accelerated): + schema_in = T.StructType([T.StructField("id", T.LongType())]) + data = [(i,) for i in range(10)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, schema_in).write.parquet(src) + + sentinel = "boom-from-pyarrow-udf" + + def boom(iterator): + for _batch in iterator: + raise ValueError(sentinel) + # Unreachable, but mapInArrow requires the callable to be a generator. + yield # pragma: no cover + + result_df = spark.read.parquet(src).mapInArrow(boom, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + with pytest.raises(Exception) as exc_info: + result_df.collect() + assert sentinel in str(exc_info.value), ( + f"expected sentinel {sentinel!r} in exception, got: {exc_info.value}" + ) + + +def test_map_in_arrow_decimal_type(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("amount", T.DecimalType(18, 6)), + ] + ) + rows = [ + (1, Decimal("123.456789")), + (2, Decimal("0.000001")), + (3, Decimal("-99999999.999999")), + (4, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["amount"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_date_and_timestamp(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("d", T.DateType()), + T.StructField("ts", T.TimestampType()), + ] + ) + rows = [ + (1, dt.date(2024, 1, 1), dt.datetime(2024, 1, 1, 12, 30, 45)), + (2, dt.date(1999, 12, 31), dt.datetime(2000, 6, 15, 0, 0, 0)), + (3, None, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["d"], r["ts"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_array_and_struct(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("nums", T.ArrayType(T.IntegerType())), + T.StructField( + "addr", + T.StructType( + [ + T.StructField("city", T.StringType()), 
+ T.StructField("zip", T.IntegerType()), + ] + ), + ), + ] + ) + rows = [ + (1, [1, 2, 3], ("Berlin", 10115)), + (2, [], ("NYC", 10001)), + (3, None, None), + (4, [None, 5], ("Tokyo", None)), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _normalize(row): + nums = tuple(row["nums"]) if row["nums"] is not None else None + addr = row["addr"] + addr_tuple = (addr["city"], addr["zip"]) if addr is not None else None + return (row["id"], nums, addr_tuple) + + out = {_normalize(r) for r in result_df.collect()} + expected = { + (r[0], tuple(r[1]) if r[1] is not None else None, r[2]) for r in rows + } + assert out == expected + + +def test_map_in_arrow_after_shuffle(spark, tmp_path, accelerated): + """ + Verifies correctness when a shuffle sits between the Comet scan and the + Python UDF. Without `spark.shuffle.manager` configured at session startup + the shuffle stays a vanilla `Exchange`, which is not columnar, so the + optimization does not fire across it today. This test does not assert on + the plan; it only ensures the path produces correct results in both modes + so a future change that wires Comet shuffle into the optimization does + not silently break correctness. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src) + .repartition(4, "id") + .mapInArrow(passthrough, schema_in) + ) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) + + +def test_map_in_arrow_barrier_mode(spark, tmp_path, accelerated): + """ + `mapInArrow(..., barrier=True)` runs the stage in barrier execution mode + (gang scheduling, all-or-nothing failure semantics, BarrierTaskContext + available inside the UDF). The optimization captures isBarrier in the + operator constructor and must propagate it through to RDD.barrier(); + otherwise the runtime context the UDF sees changes when the optimization + fires and any code calling BarrierTaskContext APIs breaks. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(20)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def assert_barrier_context(iterator): + from pyspark import BarrierTaskContext + + # Will raise if the task is not running inside a barrier stage. 
+ BarrierTaskContext.get() + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src).mapInArrow( + assert_barrier_context, schema_in, barrier=True + ) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/comet/CometMapInBatchSuite.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/comet/CometMapInBatchSuite.scala new file mode 100644 index 0000000000..af960c5c97 --- /dev/null +++ b/spark/src/test/spark-3.5/org/apache/spark/sql/comet/CometMapInBatchSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonEvalType, PythonFunction} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId, PythonUDF} +import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode} +import org.apache.spark.sql.execution.python.PythonMapInArrowExec +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometConf +import org.apache.comet.rules.EliminateRedundantTransitions + +/** Minimal CometPlan leaf used to anchor the rule's transform without triggering execution. */ +private case class StubCometLeaf(override val output: Seq[Attribute]) + extends LeafExecNode + with CometPlan { + override def supportsColumnar: Boolean = true + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + throw new UnsupportedOperationException +} + +/** + * Plan-rule test for the `EliminateRedundantTransitions` rewrite that produces + * `CometMapInBatchExec`. Pure Python execution paths are covered by the pytest module + * `test_pyarrow_udf.py`; this suite verifies the JVM-side rule without spinning up Python. + * + * Lives under `org.apache.spark.sql.comet` so it can reference Spark's `private[spark]` + * `PythonFunction` / `PythonAccumulatorV2` / `PythonBroadcast` classes when fabricating a stub + * `PythonUDF` for `PythonMapInArrowExec` to wrap. 
+ */ +class CometMapInBatchSuite extends CometTestBase { + + private def stubPythonUDF: PythonUDF = { + val pyFunc = new PythonFunction { + override val command: Seq[Byte] = Seq.empty[Byte] + override val envVars: java.util.Map[String, String] = + new java.util.HashMap[String, String]() + override val pythonIncludes: java.util.List[String] = + java.util.Collections.emptyList[String]() + override val pythonExec: String = "python3" + override val pythonVer: String = "3" + override val broadcastVars: java.util.List[Broadcast[PythonBroadcast]] = + java.util.Collections.emptyList[Broadcast[PythonBroadcast]]() + override val accumulator: PythonAccumulatorV2 = null + } + PythonUDF( + name = "test_udf", + func = pyFunc, + dataType = StructType(Seq(StructField("id", LongType))), + children = Seq(AttributeReference("id", LongType)(ExprId(0L))), + evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + udfDeterministic = true) + } + + private def buildPlan(): PythonMapInArrowExec = { + val cometChild = StubCometLeaf(Seq(AttributeReference("id", LongType)(ExprId(0L)))) + PythonMapInArrowExec( + stubPythonUDF, + cometChild.output, + ColumnarToRowExec(cometChild), + isBarrier = false) + } + + test("rule rewrites PythonMapInArrowExec over Comet to CometMapInBatchExec") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "true") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"expected CometMapInBatchExec in rewritten plan:\n$rewritten") + } + } + + test("rule does not rewrite when feature is disabled") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "false") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + !rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"unexpected CometMapInBatchExec when disabled:\n$rewritten") + } + } +} diff --git a/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala b/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala new file mode 100644 index 0000000000..5ab0b927a2 --- /dev/null +++ b/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonEvalType, PythonFunction} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId, PythonUDF} +import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode} +import org.apache.spark.sql.execution.python.MapInArrowExec +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometConf +import org.apache.comet.rules.EliminateRedundantTransitions + +/** Minimal CometPlan leaf used to anchor the rule's transform without triggering execution. */ +private case class StubCometLeaf(override val output: Seq[Attribute]) + extends LeafExecNode + with CometPlan { + override def supportsColumnar: Boolean = true + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + throw new UnsupportedOperationException +} + +/** + * Plan-rule test for the `EliminateRedundantTransitions` rewrite that produces + * `CometMapInBatchExec`. Pure Python execution paths are covered by the pytest module + * `test_pyarrow_udf.py`; this suite verifies the JVM-side rule without spinning up Python. + * + * Lives under `org.apache.spark.sql.comet` so it can reference Spark's `private[spark]` + * `PythonFunction` / `PythonAccumulatorV2` / `PythonBroadcast` classes when fabricating a stub + * `PythonUDF` for `MapInArrowExec` to wrap. + */ +class CometMapInBatchSuite extends CometTestBase { + + private def stubPythonUDF: PythonUDF = { + val pyFunc = new PythonFunction { + override val command: Seq[Byte] = Seq.empty[Byte] + override val envVars: java.util.Map[String, String] = + new java.util.HashMap[String, String]() + override val pythonIncludes: java.util.List[String] = + java.util.Collections.emptyList[String]() + override val pythonExec: String = "python3" + override val pythonVer: String = "3" + override val broadcastVars: java.util.List[Broadcast[PythonBroadcast]] = + java.util.Collections.emptyList[Broadcast[PythonBroadcast]]() + override val accumulator: PythonAccumulatorV2 = null + } + PythonUDF( + name = "test_udf", + func = pyFunc, + dataType = StructType(Seq(StructField("id", LongType))), + children = Seq(AttributeReference("id", LongType)(ExprId(0L))), + evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + udfDeterministic = true) + } + + private def buildPlan(): MapInArrowExec = { + val cometChild = StubCometLeaf(Seq(AttributeReference("id", LongType)(ExprId(0L)))) + MapInArrowExec( + stubPythonUDF, + cometChild.output, + ColumnarToRowExec(cometChild), + isBarrier = false, + profile = None) + } + + test("rule rewrites MapInArrowExec over Comet to CometMapInBatchExec") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "true") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"expected CometMapInBatchExec in rewritten plan:\n$rewritten") + } + } + + test("rule does not rewrite when feature is disabled") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "false") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + 
!rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"unexpected CometMapInBatchExec when disabled:\n$rewritten") + } + } +}
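
Illustrative addendum (not part of the patch): a minimal standalone sketch distilled from the pytest session fixture and the first `mapInArrow` test above, for exercising the new path outside the test harness. The Comet jar path and the `/tmp` output directory are placeholders, and it assumes `pyspark` and `pyarrow` are installed in the Python environment used by both the driver and the workers.

```python
# Minimal sketch, assuming a locally built Comet jar (path below is a placeholder).
import pyarrow as pa
from pyspark.sql import SparkSession
import pyspark.sql.types as T

COMET_JAR = "/path/to/comet-spark.jar"  # placeholder: point at your local Comet build

spark = (
    SparkSession.builder.master("local[2]")
    .appName("comet-pyarrow-udf-demo")
    .config("spark.jars", COMET_JAR)
    .config("spark.driver.extraClassPath", COMET_JAR)
    .config("spark.plugins", "org.apache.spark.CometPlugin")
    .config("spark.comet.enabled", "true")
    .config("spark.comet.exec.enabled", "true")
    # The flag introduced by this PR; leave it "false" to keep Spark's row-based path.
    .config("spark.comet.exec.pyarrowUdf.enabled", "true")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "2g")
    .getOrCreate()
)

def double_value(iterator):
    # Same Arrow-native UDF shape as in the tests: consume pyarrow.RecordBatch,
    # transform via pandas, yield RecordBatch back.
    for batch in iterator:
        pdf = batch.to_pandas()
        pdf["value"] = pdf["value"] * 2
        yield pa.RecordBatch.from_pandas(pdf)

schema = T.StructType(
    [
        T.StructField("id", T.LongType()),
        T.StructField("value", T.DoubleType()),
    ]
)

src = "/tmp/pyarrow_udf_demo"  # placeholder output location
spark.createDataFrame([(i, float(i)) for i in range(100)], schema) \
    .write.mode("overwrite").parquet(src)

result = spark.read.parquet(src).mapInArrow(double_value, schema)

# With the flag enabled, the executed plan should contain CometMapInBatch rather than
# a ColumnarToRow transition feeding the vanilla MapInArrow node.
print(result._jdf.queryExecution().executedPlan().toString())
print(result.orderBy("id").take(3))

spark.stop()
```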