diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index b0f09bc43b..fd92231414 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -393,6 +393,7 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometRegExpJvmSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index c743d1888a..20209b7c39 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -231,6 +231,7 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometRegExpJvmSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.gitignore b/.gitignore index a3c97ff992..eed00f3262 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ output docs/comet-*/ docs/build/ docs/temp/ +docs/superpowers/ diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java new file mode 100644 index 0000000000..24d7bcfbd0 --- /dev/null +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method + * pattern used by CometScalarSubquery so the native side can dispatch via + * call_static_method_unchecked. + */ +public class CometUdfBridge { + + // Per-thread, bounded LRU of UDF instances keyed by class name. Comet + // native execution threads (Tokio/DataFusion worker pool) are reused + // across tasks within an executor, so the effective lifetime of cached + // entries is the worker thread (i.e. the executor JVM). This is fine for + // stateless UDFs like RegExpLikeUDF; future stateful UDFs would need + // explicit per-task isolation. 
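+  // Dispatch sketch: the native side pre-allocates FFI_ArrowArray/FFI_ArrowSchema structs, +  // passes their raw addresses through the long[] parameters of evaluate() below, and this +  // bridge imports them as Arrow Java vectors before invoking the cached UDF instance.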
+  private static final int CACHE_CAPACITY = 64; + +  private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES = +      ThreadLocal.withInitial( +          () -> +              new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) { +                @Override +                protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) { +                  return size() > CACHE_CAPACITY; +                } +              }); + +  /** +   * Called from native via JNI. +   * +   * @param udfClassName fully-qualified class name implementing CometUDF +   * @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input) +   * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input) +   * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result +   * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result +   */ +  public static void evaluate( +      String udfClassName, +      long[] inputArrayPtrs, +      long[] inputSchemaPtrs, +      long outArrayPtr, +      long outSchemaPtr) { +    LinkedHashMap<String, CometUDF> cache = INSTANCES.get(); +    CometUDF udf = cache.get(udfClassName); +    if (udf == null) { +      try { +        // Resolve via the executor's context classloader so user-supplied UDF jars +        // (added via spark.jars / --jars) are visible. +        ClassLoader cl = Thread.currentThread().getContextClassLoader(); +        if (cl == null) { +          cl = CometUdfBridge.class.getClassLoader(); +        } +        udf = +            (CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance(); +      } catch (ReflectiveOperationException e) { +        throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e); +      } +      cache.put(udfClassName, udf); +    } + +    BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); + +    ValueVector[] inputs = new ValueVector[inputArrayPtrs.length]; +    ValueVector result = null; +    try { +      for (int i = 0; i < inputArrayPtrs.length; i++) { +        ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]); +        ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]); +        inputs[i] = Data.importVector(allocator, inArr, inSch, null); +      } + +      result = udf.evaluate(inputs); +      if (!(result instanceof FieldVector)) { +        throw new RuntimeException( +            "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); +      } +      // Result length must match the longest input. Scalar (length-1) inputs +      // are allowed to be shorter, but a vector input bounds the output.
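+      // For example, for rlike(subject, pattern) the subject arrives at batch-row count and +      // the literal pattern as a length-1 vector, so expectedLen below is the subject's length.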
+ int expectedLen = 0; + for (ValueVector v : inputs) { + expectedLen = Math.max(expectedLen, v.getValueCount()); + } + if (result.getValueCount() != expectedLen) { + throw new RuntimeException( + "CometUDF.evaluate() returned " + + result.getValueCount() + + " rows, expected " + + expectedLen); + } + ArrowArray outArr = ArrowArray.wrap(outArrayPtr); + ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); + Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch); + } finally { + for (ValueVector v : inputs) { + if (v != null) { + try { + v.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + if (result != null) { + try { + result.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + } +} diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index d3f51dfbe2..0c0248ad64 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -383,6 +383,24 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val REGEXP_ENGINE_RUST = "rust" + val REGEXP_ENGINE_JAVA = "java" + + val COMET_REGEXP_ENGINE: ConfigEntry[String] = + conf("spark.comet.exec.regexp.engine") + .category(CATEGORY_EXEC) + .doc( + "Experimental. Selects the engine used to evaluate supported regular-expression " + + s"expressions. `$REGEXP_ENGINE_RUST` uses the native DataFusion regexp engine. " + + s"`$REGEXP_ENGINE_JAVA` routes through a JVM-side UDF (java.util.regex.Pattern) for " + + "Spark-compatible semantics, at the cost of JNI roundtrips per batch. Expressions " + + "routed when set to java: rlike, regexp_extract, regexp_extract_all, regexp_replace, " + + "regexp_instr, and split.") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA)) + .createWithDefault(REGEXP_ENGINE_JAVA) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala new file mode 100644 index 0000000000..ac7b72a883 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.ValueVector + +/** + * Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns + * an Arrow vector. 
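+ * For example, for `rlike(subject, pattern)` the UDF receives `inputs(0)` as the subject column + * and `inputs(1)` as a length-1 pattern vector.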
+ * + * - Vector arguments arrive at the row count of the current batch. + * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. + * - The returned vector's length must match the longest input. + * + * Implementations must have a public no-arg constructor and should be stateless: instances are + * cached per executor thread for the lifetime of the JVM. + */ +trait CometUDF { + def evaluate(inputs: Array[ValueVector]): ValueVector +} diff --git a/common/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala b/common/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala new file mode 100644 index 0000000000..8756e5c292 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} +import org.apache.arrow.vector.complex.ListVector +import org.apache.arrow.vector.types.pojo.{ArrowType, FieldType} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_extract_all(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns an array of strings: for every match of pattern in subject, extracts the idx-th + * capturing group. idx=0 returns the entire match. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector group index (scalar, length-1) + * + * Output: ListVector of VarChar, same length as subject. 
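+ * + * Example: with pattern `(\d+)` and idx 1, subject `"a1b2c3"` yields `["1", "2", "3"]`.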
+ */ +class RegExpExtractAllUDF extends CometUDF { + + private val patternCache = + new util.LinkedHashMap[String, Pattern]( + RegExpExtractAllUDF.PatternCacheCapacity, + 0.75f, + true) { + override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = + size() > RegExpExtractAllUDF.PatternCacheCapacity + } + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 3, s"RegExpExtractAllUDF expects 3 inputs, got ${inputs.length}") + val subject = inputs(0).asInstanceOf[VarCharVector] + val patternVec = inputs(1).asInstanceOf[VarCharVector] + val idxVec = inputs(2).asInstanceOf[IntVector] + require( + patternVec.getValueCount >= 1 && !patternVec.isNull(0), + "RegExpExtractAllUDF requires a non-null scalar pattern") + require( + idxVec.getValueCount >= 1 && !idxVec.isNull(0), + "RegExpExtractAllUDF requires a non-null scalar group index") + + val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) + val pattern = { + val cached = patternCache.get(patternStr) + if (cached != null) cached + else { + val compiled = Pattern.compile(patternStr) + patternCache.put(patternStr, compiled) + compiled + } + } + val idx = idxVec.get(0) + + val n = subject.getValueCount + val out = ListVector.empty("regexp_extract_all_result", CometArrowAllocator) + out.addOrGetVector[VarCharVector](new FieldType(true, ArrowType.Utf8.INSTANCE, null)) + val writer = out.getWriter + + var i = 0 + while (i < n) { + if (subject.isNull(i)) { + out.setNull(i) + } else { + val s = new String(subject.get(i), StandardCharsets.UTF_8) + val matcher = pattern.matcher(s) + writer.setPosition(i) + writer.startList() + while (matcher.find()) { + if (idx <= matcher.groupCount()) { + val group = matcher.group(idx) + val bytes = + if (group == null) "".getBytes(StandardCharsets.UTF_8) + else group.getBytes(StandardCharsets.UTF_8) + val buf = CometArrowAllocator.buffer(bytes.length) + buf.writeBytes(bytes) + writer.varChar().writeVarChar(0, bytes.length, buf) + buf.close() + } else { + val bytes = "".getBytes(StandardCharsets.UTF_8) + val buf = CometArrowAllocator.buffer(bytes.length) + buf.writeBytes(bytes) + writer.varChar().writeVarChar(0, bytes.length, buf) + buf.close() + } + } + writer.endList() + } + i += 1 + } + out.setValueCount(n) + out + } +} + +object RegExpExtractAllUDF { + private val PatternCacheCapacity: Int = 128 +} diff --git a/common/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala b/common/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala new file mode 100644 index 0000000000..d46ca03252 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_extract(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns the string matching the idx-th capturing group of the first match, or empty string if + * no match. idx=0 returns the entire match. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector group index (scalar, length-1) + * + * Output: VarCharVector, same length as subject. + */ +class RegExpExtractUDF extends CometUDF { + + private val patternCache = + new util.LinkedHashMap[String, Pattern](RegExpExtractUDF.PatternCacheCapacity, 0.75f, true) { + override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = + size() > RegExpExtractUDF.PatternCacheCapacity + } + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 3, s"RegExpExtractUDF expects 3 inputs, got ${inputs.length}") + val subject = inputs(0).asInstanceOf[VarCharVector] + val patternVec = inputs(1).asInstanceOf[VarCharVector] + val idxVec = inputs(2).asInstanceOf[IntVector] + require( + patternVec.getValueCount >= 1 && !patternVec.isNull(0), + "RegExpExtractUDF requires a non-null scalar pattern") + require( + idxVec.getValueCount >= 1 && !idxVec.isNull(0), + "RegExpExtractUDF requires a non-null scalar group index") + + val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) + val pattern = { + val cached = patternCache.get(patternStr) + if (cached != null) cached + else { + val compiled = Pattern.compile(patternStr) + patternCache.put(patternStr, compiled) + compiled + } + } + val idx = idxVec.get(0) + + val n = subject.getValueCount + val out = new VarCharVector("regexp_extract_result", CometArrowAllocator) + out.allocateNew(n) + + var i = 0 + while (i < n) { + if (subject.isNull(i)) { + out.setNull(i) + } else { + val s = new String(subject.get(i), StandardCharsets.UTF_8) + val matcher = pattern.matcher(s) + if (matcher.find() && idx <= matcher.groupCount()) { + val group = matcher.group(idx) + if (group == null) { + out.setSafe(i, "".getBytes(StandardCharsets.UTF_8)) + } else { + out.setSafe(i, group.getBytes(StandardCharsets.UTF_8)) + } + } else { + out.setSafe(i, "".getBytes(StandardCharsets.UTF_8)) + } + } + i += 1 + } + out.setValueCount(n) + out + } +} + +object RegExpExtractUDF { + private val PatternCacheCapacity: Int = 128 +} diff --git a/common/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala b/common/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala new file mode 100644 index 0000000000..4beee37e3d --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_instr(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns the 1-based position of the start of the first match, or 0 if no match. The group + * index argument is accepted for signature parity, but the returned position is always that of + * the entire match, matching Spark's behavior (see the comment in `evaluate`). + * + * Inputs: + *   - inputs(0): VarCharVector subject column + *   - inputs(1): VarCharVector pattern (scalar, length-1) + *   - inputs(2): IntVector group index (scalar, length-1) + * + * Output: IntVector, same length as subject. + */ +class RegExpInStrUDF extends CometUDF { + +  private val patternCache = +    new util.LinkedHashMap[String, Pattern](RegExpInStrUDF.PatternCacheCapacity, 0.75f, true) { +      override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = +        size() > RegExpInStrUDF.PatternCacheCapacity +    } + +  override def evaluate(inputs: Array[ValueVector]): ValueVector = { +    require(inputs.length == 3, s"RegExpInStrUDF expects 3 inputs, got ${inputs.length}") +    val subject = inputs(0).asInstanceOf[VarCharVector] +    val patternVec = inputs(1).asInstanceOf[VarCharVector] +    val idxVec = inputs(2).asInstanceOf[IntVector] +    require( +      patternVec.getValueCount >= 1 && !patternVec.isNull(0), +      "RegExpInStrUDF requires a non-null scalar pattern") +    require( +      idxVec.getValueCount >= 1 && !idxVec.isNull(0), +      "RegExpInStrUDF requires a non-null scalar group index") + +    val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) +    val pattern = { +      val cached = patternCache.get(patternStr) +      if (cached != null) cached +      else { +        val compiled = Pattern.compile(patternStr) +        patternCache.put(patternStr, compiled) +        compiled +      } +    } +    val idx = idxVec.get(0) + +    val n = subject.getValueCount +    val out = new IntVector("regexp_instr_result", CometArrowAllocator) +    out.allocateNew(n) + +    var i = 0 +    while (i < n) { +      if (subject.isNull(i)) { +        out.setNull(i) +      } else { +        val s = new String(subject.get(i), StandardCharsets.UTF_8) +        val matcher = pattern.matcher(s) +        if (matcher.find()) { +          // Spark regexp_instr always returns 1-based position of the entire match start +          out.set(i, matcher.start() + 1) +        } else { +          out.set(i, 0) +        } +      } +      i += 1 +    } +    out.setValueCount(n) +    out +  } +} + +object RegExpInStrUDF { +  private val PatternCacheCapacity: Int = 128 +} diff --git a/common/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala b/common/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala new file mode 100644 index 0000000000..35ad1816e3 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.{BitVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp` / `RLike` implemented with java.util.regex.Pattern (Java semantics). + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern, 1-row scalar (serde guarantees this) + * + * Output: BitVector (Arrow boolean), same length as the subject vector. + */ +class RegExpLikeUDF extends CometUDF { + + // Bounded LRU so a workload with many distinct patterns does not retain + // Pattern objects for the executor's lifetime. + private val patternCache = + new util.LinkedHashMap[String, Pattern](RegExpLikeUDF.PatternCacheCapacity, 0.75f, true) { + override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = + size() > RegExpLikeUDF.PatternCacheCapacity + } + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 2, s"RegExpLikeUDF expects 2 inputs, got ${inputs.length}") + val subject = inputs(0).asInstanceOf[VarCharVector] + val patternVec = inputs(1).asInstanceOf[VarCharVector] + require( + patternVec.getValueCount >= 1 && !patternVec.isNull(0), + "RegExpLikeUDF requires a non-null scalar pattern") + + val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) + val pattern = { + val cached = patternCache.get(patternStr) + if (cached != null) cached + else { + val compiled = Pattern.compile(patternStr) + patternCache.put(patternStr, compiled) + compiled + } + } + + val n = subject.getValueCount + val out = new BitVector("rlike_result", CometArrowAllocator) + out.allocateNew(n) + + var i = 0 + while (i < n) { + if (subject.isNull(i)) { + out.setNull(i) + } else { + val s = new String(subject.get(i), StandardCharsets.UTF_8) + out.set(i, if (pattern.matcher(s).find()) 1 else 0) + } + i += 1 + } + out.setValueCount(n) + out + } +} + +object RegExpLikeUDF { + private val PatternCacheCapacity: Int = 128 +} diff --git a/common/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala b/common/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala new file mode 100644 index 0000000000..b6d06ecd64 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.{ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_replace(subject, pattern, replacement)` implemented with java.util.regex.Pattern. + * + * Replaces all occurrences of pattern in subject with replacement. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): VarCharVector replacement (scalar, length-1) + * + * Output: VarCharVector, same length as subject. + */ +class RegExpReplaceUDF extends CometUDF { + + private val patternCache = + new util.LinkedHashMap[String, Pattern](RegExpReplaceUDF.PatternCacheCapacity, 0.75f, true) { + override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = + size() > RegExpReplaceUDF.PatternCacheCapacity + } + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 3, s"RegExpReplaceUDF expects 3 inputs, got ${inputs.length}") + val subject = inputs(0).asInstanceOf[VarCharVector] + val patternVec = inputs(1).asInstanceOf[VarCharVector] + val replacementVec = inputs(2).asInstanceOf[VarCharVector] + require( + patternVec.getValueCount >= 1 && !patternVec.isNull(0), + "RegExpReplaceUDF requires a non-null scalar pattern") + require( + replacementVec.getValueCount >= 1 && !replacementVec.isNull(0), + "RegExpReplaceUDF requires a non-null scalar replacement") + + val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) + val pattern = { + val cached = patternCache.get(patternStr) + if (cached != null) cached + else { + val compiled = Pattern.compile(patternStr) + patternCache.put(patternStr, compiled) + compiled + } + } + val replacement = new String(replacementVec.get(0), StandardCharsets.UTF_8) + + val n = subject.getValueCount + val out = new VarCharVector("regexp_replace_result", CometArrowAllocator) + out.allocateNew(n) + + var i = 0 + while (i < n) { + if (subject.isNull(i)) { + out.setNull(i) + } else { + val s = new String(subject.get(i), StandardCharsets.UTF_8) + val result = pattern.matcher(s).replaceAll(replacement) + out.setSafe(i, result.getBytes(StandardCharsets.UTF_8)) + } + i += 1 + } + out.setValueCount(n) + out + } +} + +object RegExpReplaceUDF { + private val PatternCacheCapacity: Int = 128 +} diff --git a/common/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala b/common/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala new file mode 100644 index 0000000000..03fb7bb75f --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util +import java.util.regex.Pattern + +import org.apache.arrow.vector.ValueVector +import org.apache.arrow.vector.VarCharVector +import org.apache.arrow.vector.complex.ListVector +import org.apache.arrow.vector.types.pojo.{ArrowType, FieldType} + +import org.apache.comet.CometArrowAllocator + +/** + * `split(subject, pattern, limit)` implemented with java.util.regex.Pattern. + * + * Splits the subject string around matches of the pattern, up to the specified limit. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector limit (scalar, length-1) + * + * Output: ListVector of VarChar, same length as subject. + */ +class StringSplitUDF extends CometUDF { + + private val patternCache = + new util.LinkedHashMap[String, Pattern](StringSplitUDF.PatternCacheCapacity, 0.75f, true) { + override def removeEldestEntry(eldest: util.Map.Entry[String, Pattern]): Boolean = + size() > StringSplitUDF.PatternCacheCapacity + } + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 3, s"StringSplitUDF expects 3 inputs, got ${inputs.length}") + val subject = inputs(0).asInstanceOf[VarCharVector] + val patternVec = inputs(1).asInstanceOf[VarCharVector] + val limitVec = inputs(2).asInstanceOf[org.apache.arrow.vector.IntVector] + require( + patternVec.getValueCount >= 1 && !patternVec.isNull(0), + "StringSplitUDF requires a non-null scalar pattern") + require( + limitVec.getValueCount >= 1 && !limitVec.isNull(0), + "StringSplitUDF requires a non-null scalar limit") + + val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8) + val pattern = { + val cached = patternCache.get(patternStr) + if (cached != null) cached + else { + val compiled = Pattern.compile(patternStr) + patternCache.put(patternStr, compiled) + compiled + } + } + val limit = limitVec.get(0) + + val n = subject.getValueCount + val out = ListVector.empty("string_split_result", CometArrowAllocator) + out.addOrGetVector[VarCharVector](new FieldType(true, ArrowType.Utf8.INSTANCE, null)) + val writer = out.getWriter + + var i = 0 + while (i < n) { + if (subject.isNull(i)) { + out.setNull(i) + } else { + val s = new String(subject.get(i), StandardCharsets.UTF_8) + // Spark semantics: limit <= 0 means no limit (split returns all) + val parts = if (limit <= 0) pattern.split(s, -1) else pattern.split(s, limit) + writer.setPosition(i) + writer.startList() + var j = 0 + while (j < parts.length) { + val bytes = parts(j).getBytes(StandardCharsets.UTF_8) + val buf = CometArrowAllocator.buffer(bytes.length) + buf.writeBytes(bytes) + writer.varChar().writeVarChar(0, bytes.length, buf) + buf.close() + j += 1 + } + writer.endList() + } + i += 1 + } + out.setValueCount(n) + out + } +} + +object StringSplitUDF { + private val PatternCacheCapacity: Int 
= 128 } diff --git a/docs/source/user-guide/latest/compatibility/regex.md b/docs/source/user-guide/latest/compatibility/regex.md index 4d9d5b650c..0522ecc47c 100644 --- a/docs/source/user-guide/latest/compatibility/regex.md +++ b/docs/source/user-guide/latest/compatibility/regex.md @@ -19,6 +19,97 @@ under the License. # Regular Expressions
-Comet uses the Rust regexp crate for evaluating regular expressions, and this has different behavior from Java's
-regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but
-this can be overridden by setting `spark.comet.expression.regexp.allowIncompatible=true`.
+Comet provides two regexp engines for evaluating regular expressions: a **Java engine** that calls back into
+the JVM and a **Rust engine** that uses the Rust [`regex`] crate natively. The engine is selected with:
+
+```
+spark.comet.exec.regexp.engine=java   # default
+spark.comet.exec.regexp.engine=rust
+```
+
+## Choosing an engine
+
+| | Java engine | Rust engine |
+|---|---|---|
+| **Compatibility** | 100% compatible with Spark | Pattern-dependent differences |
+| **Feature coverage** | All regexp expressions (`rlike`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, `regexp_replace`, `split`) | `rlike`, `regexp_replace`, `split` only |
+| **Performance** | One JNI round-trip per batch (Arrow vectors stay columnar) | Fully native, no JNI overhead |
+| **Pattern support** | All Java regex features (backreferences, lookaround, etc.) | Linear-time subset only |
+
+The **Java engine** (default) is recommended for correctness-sensitive workloads. It evaluates expressions by
+passing Arrow vectors to a JVM-side UDF that uses `java.util.regex`, producing identical results to Spark for
+all patterns.
+
+The **Rust engine** is faster but only supports a subset of patterns. When it encounters a pattern it cannot
+handle, it falls back to Spark automatically. To opt in to native evaluation for patterns Comet considers
+potentially incompatible, set:
+
+```
+spark.comet.expression.regexp.allowIncompatible=true
+```
+
+## Why the engines differ
+
+Java's `java.util.regex` is a backtracking engine in the Perl/PCRE family. It supports the full range of
+features that style of engine provides, including some whose worst-case running time grows exponentially with
+the input.
+
+Rust's [`regex`] crate is a finite-automaton engine in the [RE2] family. It deliberately omits features that
+cannot be implemented with a guarantee of linear-time matching. In exchange, every pattern it does accept runs
+in time linear in the size of the input. This is the same trade-off RE2, Go's `regexp`, and several other
+engines make.
+
+The practical consequence is that Java accepts a strictly larger set of patterns than the Rust engine, and
+several constructs that look the same in source have different semantics on the two sides.
+
+## Features supported by Java but not by the Rust engine
+
+Patterns that use any of the following will not compile in Comet's Rust engine and must run on Spark (or use
+the Java engine):
+
+- **Backreferences** such as `\1`, `\2`, or `\k<name>`. The Rust engine has no backtracking and cannot match
+  a previously captured group.
+- **Lookaround**, including lookahead (`(?=...)`, `(?!...)`) and lookbehind (`(?<=...)`, `(?<!...)`).
+- **Possessive quantifiers** (`*+`, `++`, `?+`, `{n,m}+`). Rust supports greedy and lazy quantifiers but not
+  possessive.
+- **Embedded code, conditionals, and recursion** such as `(?(cond)yes|no)` or `(?R)`.
+  Rust accepts none of these.
+
+## Features that exist on both sides but behave differently
+
+Even where both engines accept a construct, the matching behavior is not always the same.
+
+- **Unicode-aware character classes.** In the Rust engine, `\d`, `\w`, `\s`, and `.` are Unicode-aware by
+  default, so `\d` matches every digit codepoint defined by Unicode rather than only `0`-`9`. Java's defaults
+  match ASCII only and require the `UNICODE_CHARACTER_CLASS` flag (or `(?U)` inline) to switch to Unicode
+  semantics. The same pattern can therefore match a different set of characters on each side.
+- **Line terminators.** In multiline mode, Java treats `\r`, `\n`, `\r\n`, and a few additional Unicode line
+  separators as line boundaries by default. The Rust engine treats only `\n` as a line boundary unless CRLF
+  mode is enabled. `^`, `$`, and `.` (with `(?s)` off) all depend on this definition.
+- **Case-insensitive matching.** Both engines support `(?i)`, but Java's default is ASCII case folding while
+  the Rust engine uses full Unicode simple case folding when Unicode mode is on. Patterns that match
+  characters outside ASCII can produce different results.
+- **POSIX character classes.** The Rust engine supports `[[:alpha:]]` style POSIX classes inside bracket
+  expressions but not Java's `\p{Alpha}` shorthand. Java accepts both. Unicode property escapes (`\p{L}`,
+  `\p{Greek}`, etc.) are supported by both engines but cover slightly different sets of properties.
+- **Octal and Unicode escapes.** Java accepts `\0nnn` for octal and `\uXXXX` for a BMP codepoint. Rust uses
+  `\x{...}` for arbitrary codepoints and does not accept Java's bare `\uXXXX` form.
+- **Empty matches in `split`.** Spark's `StringSplit` is built on Java's regex `split`, and the two engines
+  follow different rules for zero-width matches (for example at the start of the input, where Java since
+  JDK 8 drops the empty leading substring produced by a zero-width match). Split results can therefore differ
+  in edge cases involving empty matches even when the pattern itself is identical on both sides.
+
+## When the Rust engine is safe
+
+For most ASCII-only, non-anchored patterns that use only literal characters, simple character classes, and
+ordinary quantifiers, the two engines produce the same results. If you are confident your patterns fit this
+shape and want to avoid the JNI overhead of the Java engine, switching to the Rust engine with
+`allowIncompatible=true` is generally safe.
+
+For anything that uses backreferences, lookaround, or relies on Java's specific Unicode or line-handling
+defaults, use the Java engine (the default).
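+
+As a concrete illustration of the trade-off, here is a sketch of a spark-shell session (the literal
+values are only examples):
+
+```scala
+// Java engine (default): a backreference is valid java.util.regex syntax, so the
+// expression runs through Comet's JVM UDF path and matches Spark exactly.
+spark.conf.set("spark.comet.exec.regexp.engine", "java")
+spark.sql("""SELECT 'abcabc' RLIKE '(abc)\\1'""").show() // true
+
+// Rust engine: the regex crate rejects backreferences at compile time, so Comet
+// falls back to Spark for this expression rather than returning a different answer.
+spark.conf.set("spark.comet.exec.regexp.engine", "rust")
+spark.sql("""SELECT 'abcabc' RLIKE '(abc)\\1'""").show() // true, via Spark fallback
+```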
+
+[`java.util.regex`]: https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
+[`regex`]: https://docs.rs/regex/latest/regex/
+[RE2]: https://github.com/google/re2/wiki/Syntax
diff --git a/native/Cargo.lock b/native/Cargo.lock index ae2d6b074c..75e84d851d 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2116,8 +2116,10 @@ dependencies = [ "criterion", "datafusion", "datafusion-comet-common", + "datafusion-comet-jni-bridge", "futures", "hex", + "jni 0.22.4", "num", "rand 0.10.1", "regex", diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 844cc07c69..6019f168cc 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -122,10 +122,10 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ -    ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, -    DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, -    NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, -    WideDecimalBinaryExpr, WideDecimalOp, +    jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, +    Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, +    GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, +    ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -701,6 +701,23 @@ impl PhysicalPlanner { expr.names.clone(), ))) } +            ExprStruct::JvmScalarUdf(udf) => { +                let args = udf +                    .args +                    .iter() +                    .map(|e| self.create_expr(e, Arc::clone(&input_schema))) +                    .collect::<Result<Vec<_>, _>>()?; +                let return_type = +                    to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| { +                        GeneralError("JvmScalarUdf missing return_type".to_string()) +                    })?); +                Ok(Arc::new(JvmScalarUdfExpr::new( +                    udf.class_name.clone(), +                    args, +                    return_type, +                    udf.return_nullable, +                ))) +            } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } } diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs new file mode 100644 index 0000000000..89cd8ee514 --- /dev/null +++ b/native/jni-bridge/src/comet_udf_bridge.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.  See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.  The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.  You may obtain a copy of the License at +// +//   http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.  See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::{ +    errors::Result as JniResult, +    objects::{JClass, JStaticMethodID}, +    signature::{Primitive, ReturnType}, +    strings::JNIString, +    Env, +}; + +/// JNI handle for the JVM `org.apache.comet.udf.CometUdfBridge` class.
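+/// The class reference and the static method ID for `evaluate` are resolved once at +/// JVM-classes init and cached, so per-batch dispatch avoids repeated lookups.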
+/// Mirrors the static-method pattern in `comet_exec.rs` (`CometScalarSubquery`). +#[allow(dead_code)] // class field is held to keep JStaticMethodID alive +pub struct CometUdfBridge<'a> { +    pub class: JClass<'a>, +    pub method_evaluate: JStaticMethodID, +    pub method_evaluate_ret: ReturnType, +} + +impl<'a> CometUdfBridge<'a> { +    pub const JVM_CLASS: &'static str = "org/apache/comet/udf/CometUdfBridge"; + +    pub fn new(env: &mut Env<'a>) -> JniResult<CometUdfBridge<'a>> { +        let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; +        Ok(CometUdfBridge { +            method_evaluate: env.get_static_method_id( +                JNIString::new(Self::JVM_CLASS), +                jni::jni_str!("evaluate"), +                jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"), +            )?, +            method_evaluate_ret: ReturnType::Primitive(Primitive::Void), +            class, +        }) +    } +} diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 21c647135b..f95d3cc174 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -192,11 +192,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod comet_udf_bridge; mod shuffle_block_iterator; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use comet_udf_bridge::CometUdfBridge; use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. @@ -228,6 +230,10 @@ pub struct JVMClasses<'a> { /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, +    /// The CometUdfBridge class used to dispatch JVM scalar UDFs. +    /// `None` if the class is not on the classpath; the JVM-UDF dispatch path +    /// reports a clear error rather than crashing executor init. +    pub comet_udf_bridge: Option<CometUdfBridge<'a>>, } unsafe impl Send for JVMClasses<'_> {} @@ -298,6 +304,16 @@ impl JVMClasses<'_> { comet_batch_iterator: CometBatchIterator::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), +                comet_udf_bridge: { +                    // Optional: if the bridge class is absent (e.g. comet shading +                    // dropped org.apache.comet.udf.*), record None and clear the +                    // pending JVM exception so other JNI calls keep working. +                    let bridge = CometUdfBridge::new(env).ok(); +                    if env.exception_check() { +                        env.exception_clear(); +                    } +                    bridge +                }, } }); } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index c7a305285d..90e3d87032 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -90,6 +90,7 @@ message Expr { ToCsv to_csv = 67; HoursTransform hours_transform = 68; ArraysZip arrays_zip = 69; +    JvmScalarUdf jvm_scalar_udf = 70; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -514,3 +515,18 @@ message ArraysZip { repeated Expr values = 1; repeated string names = 2; } + +// Scalar UDF dispatched to the JVM via JNI. Native side exports input arrays +// through Arrow C Data Interface, calls CometUdfBridge.evaluate, and imports +// the result. +message JvmScalarUdf { +  // Fully-qualified Java/Scala class name implementing +  // org.apache.comet.udf.CometUDF (must have a public no-arg constructor). +  string class_name = 1; +  // Argument expressions, evaluated by the native side before invocation. +  repeated Expr args = 2; +  // Expected return type. Used to import the result FFI_ArrowArray.
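+  // Also used to cast the imported result when the JVM reports different child +  // field names (e.g. Arrow Java's list children) than DataFusion expects.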
+ DataType return_type = 3; + // Whether the result column may contain nulls. + bool return_nullable = 4; +} diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index e9a4a546c1..33ffc1c886 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -36,6 +36,8 @@ regex = { workspace = true } # preserve_order: needed for get_json_object to match Spark's JSON key ordering serde_json = { version = "1.0", features = ["preserve_order"] } datafusion-comet-common = { workspace = true } +datafusion-comet-jni-bridge = { workspace = true } +jni = "0.22.4" futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs new file mode 100644 index 0000000000..e27082099a --- /dev/null +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::{make_array, ArrayRef}; +use arrow::datatypes::{DataType, Schema}; +use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::record_batch::RecordBatch; + +use datafusion::common::Result as DFResult; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; + +use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError}; +use datafusion_comet_jni_bridge::JVMClasses; +use jni::objects::{JObject, JValue}; + +/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI. +/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`. 
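+/// +/// Evaluation round-trip: child expressions -> Arrow arrays -> FFI_ArrowArray/FFI_ArrowSchema +/// structs on the Rust heap -> `CometUdfBridge.evaluate` over JNI -> result imported back from +/// the out slots via `from_ffi`.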
+#[derive(Debug)] +pub struct JvmScalarUdfExpr { +    class_name: String, +    args: Vec<Arc<dyn PhysicalExpr>>, +    return_type: DataType, +    return_nullable: bool, +} + +impl JvmScalarUdfExpr { +    pub fn new( +        class_name: String, +        args: Vec<Arc<dyn PhysicalExpr>>, +        return_type: DataType, +        return_nullable: bool, +    ) -> Self { +        Self { +            class_name, +            args, +            return_type, +            return_nullable, +        } +    } +} + +impl Display for JvmScalarUdfExpr { +    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +        write!(f, "JvmScalarUdf({}", self.class_name)?; +        for a in &self.args { +            write!(f, ", {a}")?; +        } +        write!(f, ")") +    } +} + +impl Hash for JvmScalarUdfExpr { +    fn hash<H: Hasher>(&self, state: &mut H) { +        self.class_name.hash(state); +        for a in &self.args { +            a.hash(state); +        } +        self.return_type.hash(state); +        self.return_nullable.hash(state); +    } +} + +impl PartialEq for JvmScalarUdfExpr { +    fn eq(&self, other: &Self) -> bool { +        self.class_name == other.class_name +            && self.return_type == other.return_type +            && self.return_nullable == other.return_nullable +            && self.args.len() == other.args.len() +            && self.args.iter().zip(&other.args).all(|(a, b)| a.eq(b)) +    } +} + +impl Eq for JvmScalarUdfExpr {} + +impl PhysicalExpr for JvmScalarUdfExpr { +    fn as_any(&self) -> &dyn Any { +        self +    } + +    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +        Display::fmt(self, f) +    } + +    fn data_type(&self, _input_schema: &Schema) -> DFResult<DataType> { +        Ok(self.return_type.clone()) +    } + +    fn nullable(&self, _input_schema: &Schema) -> DFResult<bool> { +        Ok(self.return_nullable) +    } + +    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> { +        // Step 1: evaluate child expressions to get Arrow arrays. Scalar children +        // (e.g. literal patterns) are sent as length-1 vectors rather than expanded +        // to batch-row count, so the JVM bridge does not pay an O(rows) copy for +        // values that never vary across the batch. +        let arrays: Vec<ArrayRef> = self +            .args +            .iter() +            .map(|e| match e.evaluate(batch)? { +                ColumnarValue::Array(a) => Ok(a), +                ColumnarValue::Scalar(s) => s.to_array_of_size(1), +            }) +            .collect::<DFResult<Vec<_>>>()?; + +        // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers. +        // The JVM writes into the out_array/out_schema slots and reads from the in_ slots. +        let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays +            .iter() +            .map(|arr| Box::new(FFI_ArrowArray::new(&arr.to_data()))) +            .collect(); +        let in_ffi_schemas: Vec<Box<FFI_ArrowSchema>> = arrays +            .iter() +            .map(|arr| { +                FFI_ArrowSchema::try_from(arr.data_type()) +                    .map(Box::new) +                    .map_err(|e| CometError::Arrow { source: e }) +            }) +            .collect::<Result<Vec<_>, _>>()?; + +        let in_arr_ptrs: Vec<i64> = in_ffi_arrays +            .iter() +            .map(|b| b.as_ref() as *const FFI_ArrowArray as i64) +            .collect(); +        let in_sch_ptrs: Vec<i64> = in_ffi_schemas +            .iter() +            .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64) +            .collect(); + +        // Allocate output FFI slots. +        let mut out_array = Box::new(FFI_ArrowArray::empty()); +        let mut out_schema = Box::new(FFI_ArrowSchema::empty()); +        let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64; +        let out_sch_ptr = out_schema.as_mut() as *mut FFI_ArrowSchema as i64; + +        let class_name = self.class_name.clone(); +        let n_args = arrays.len(); + +        // Step 3: attach a JNI env for this thread and call the static bridge method. +        JVMClasses::with_env(|env| { +            let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { +                CometError::from(ExecutionError::GeneralError( +                    "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \ +                    class was not found on the JVM classpath.
Set \ +                    spark.comet.exec.regexp.engine=rust to disable this path." +                        .to_string(), +                )) +            })?; + +            // Build the JVM String for the class name. +            let jclass_name = env +                .new_string(&class_name) +                .map_err(|e| CometError::JNI { source: e })?; + +            // Build the long[] arrays for input pointers. +            let in_arr_java = env +                .new_long_array(n_args) +                .map_err(|e| CometError::JNI { source: e })?; +            in_arr_java +                .set_region(env, 0, &in_arr_ptrs) +                .map_err(|e| CometError::JNI { source: e })?; + +            let in_sch_java = env +                .new_long_array(n_args) +                .map_err(|e| CometError::JNI { source: e })?; +            in_sch_java +                .set_region(env, 0, &in_sch_ptrs) +                .map_err(|e| CometError::JNI { source: e })?; + +            // Call CometUdfBridge.evaluate(String, long[], long[], long, long) +            let ret = unsafe { +                env.call_static_method_unchecked( +                    &bridge.class, +                    bridge.method_evaluate, +                    bridge.method_evaluate_ret, +                    &[ +                        JValue::from(&jclass_name).as_jni(), +                        JValue::Object(JObject::from(in_arr_java).as_ref()).as_jni(), +                        JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(), +                        JValue::Long(out_arr_ptr).as_jni(), +                        JValue::Long(out_sch_ptr).as_jni(), +                    ], +                ) +            }; + +            if let Some(exception) = datafusion_comet_jni_bridge::check_exception(env)? { +                return Err(exception); +            } + +            ret.map_err(|e| CometError::JNI { source: e })?; +            Ok(()) +        })?; + +        // Step 4: import the result from the FFI slots filled by the JVM. +        // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap +        // allocation is freed by the move), and `from_ffi` wraps it in an Arc that +        // keeps the JVM-installed release callback alive until the resulting +        // ArrayData drops. `out_schema` is borrowed; its release callback runs +        // exactly once when the Box drops at end of scope. +        let result_data = unsafe { from_ffi(*out_array, &out_schema) } +            .map_err(|e| CometError::Arrow { source: e })?; +        let result_array = make_array(result_data); + +        // The JVM may produce arrays with different field names (e.g. Arrow Java's +        // ListVector uses "$data$" for child fields) than what DataFusion expects +        // (e.g. "item"). Cast to the declared return_type to normalize schema. +        let result_array = if result_array.data_type() != &self.return_type { +            arrow::compute::cast(&result_array, &self.return_type) +                .map_err(|e| CometError::Arrow { source: e })? +        } else { +            result_array +        }; + +        Ok(ColumnarValue::Array(result_array)) +    } + +    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { +        self.args.iter().collect() +    } + +    fn with_new_children( +        self: Arc<Self>, +        children: Vec<Arc<dyn PhysicalExpr>>, +    ) -> DFResult<Arc<dyn PhysicalExpr>> { +        Ok(Arc::new(JvmScalarUdfExpr::new( +            self.class_name.clone(), +            children, +            self.return_type.clone(), +            self.return_nullable, +        ))) +    } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index eddf2ff460..df6a82ae74 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -57,6 +57,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain, SparkBloomFilter mod conditional_funcs; mod conversion_funcs; +pub mod jvm_udf; mod map_funcs; pub use map_funcs::spark_map_sort; mod math_funcs; diff --git a/pom.xml b/pom.xml index b83a6fd45b..3e104667ee 100644 --- a/pom.xml +++ b/pom.xml @@ -1148,6 +1148,7 @@ under the License.
native/proto/src/generated/** benchmarks/tpc/queries/** .claude/** + docs/superpowers/** diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 3fc4adb623..c2a3a67747 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -177,6 +177,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, + classOf[RegExpInStr] -> CometRegExpInStr, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 968fe8cd69..6840cd9c13 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,15 +21,15 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataTypes, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp} import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType} object CometStringRepeat extends CometExpressionSerde[StringRepeat] { @@ -264,9 +264,32 @@ object CometLike extends CometExpressionSerde[Like] { object CometRLike extends CometExpressionSerde[RLike] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Uses Rust regexp engine, which has different behavior to Java regexp engine") + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Uses Rust regexp engine, which has different behavior to Java regexp engine") + + override def getSupportLevel(expr: RLike): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.right match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + super.getSupportLevel(expr) + } + } override def convert(expr: RLike, inputs: Seq[Attribute], binding: 
Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + convertViaJvmUdf(expr, inputs, binding) + } else { + convertViaNativeRegex(expr, inputs, binding) + } + } + + private def convertViaNativeRegex( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { expr.right match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -291,6 +314,256 @@ object CometRLike extends CometExpressionSerde[RLike] { None } } + + private def convertViaJvmUdf( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.right match { + case Literal(value, DataTypes.StringType) => + if (value == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + val patternStr = value.toString + try { + java.util.regex.Pattern.compile(patternStr) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.left, inputs, binding) + val patternProto = exprToProtoInternal(expr.right, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.BooleanType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpLikeUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } + } +} + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_extract requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_extract requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(idx, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.StringType).getOrElse(return 
None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpExtractUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_extract_all requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_extract_all requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(idx, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = + serializeDataType(ArrayType(StringType, containsNull = true)).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpExtractAllUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + +object CometRegExpInStr extends CometExpressionSerde[RegExpInStr] { + + override def getSupportLevel(expr: RegExpInStr): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_instr requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpInStr, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != 
CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_instr requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(idx, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.IntegerType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpInStrUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } } object CometStringRPad extends CometExpressionSerde[StringRPad] { @@ -352,23 +625,28 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Regexp pattern may not be compatible with Spark") + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Regexp pattern may not be compatible with Spark") override def getUnsupportedReasons(): Seq[String] = Seq( "Only supports `regexp_replace` with an offset of 1 (no offset)") override def getSupportLevel(expr: RegExpReplace): SupportLevel = { - if (!RegExp.isSupportedPattern(expr.regexp.toString) && - !CometConf.isExprAllowIncompat("regexp")) { - withInfo( - expr, - s"Regexp pattern ${expr.regexp} is not compatible with Spark. 
" + - s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + - "to allow it anyway.") - return Incompatible() - } expr.pos match { - case Literal(value, DataTypes.IntegerType) if value == 1 => Compatible() + case Literal(value, DataTypes.IntegerType) if value == 1 => + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.regexp match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + if (!RegExp.isSupportedPattern(expr.regexp.toString) && + !CometConf.isExprAllowIncompat("regexp")) { + Incompatible() + } else { + Compatible() + } + } case _ => Unsupported(Some("Comet only supports regexp_replace with an offset of 1 (no offset).")) } @@ -378,6 +656,26 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { expr: RegExpReplace, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + convertViaJvmUdf(expr, inputs, binding) + } else { + convertViaNativeRegex(expr, inputs, binding) + } + } + + private def convertViaNativeRegex( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (!RegExp.isSupportedPattern(expr.regexp.toString) && + !CometConf.isExprAllowIncompat("regexp")) { + withInfo( + expr, + s"Regexp pattern ${expr.regexp} is not compatible with Spark. " + + s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + + "to allow it anyway.") + return None + } val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val replacementExpr = exprToProtoInternal(expr.rep, inputs, binding) @@ -392,6 +690,49 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { flagsExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.rep, expr.pos) } + + private def convertViaJvmUdf( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val repProto = exprToProtoInternal(expr.rep, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || repProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.StringType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpReplaceUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(repProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } + } } /** @@ -402,15 +743,35 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { object CometStringSplit extends CometExpressionSerde[StringSplit] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Regex engine differences between 
Java and Rust") - - override def getSupportLevel(expr: StringSplit): SupportLevel = - Incompatible(Some("Regex engine differences between Java and Rust")) + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Regex engine differences between Java and Rust") + + override def getSupportLevel(expr: StringSplit): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.regex match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regex patterns are supported")) + } + } else { + Incompatible(Some("Regex engine differences between Java and Rust")) + } + } override def convert( expr: StringSplit, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + convertViaJvmUdf(expr, inputs, binding) + } else { + convertViaNativeRegex(expr, inputs, binding) + } + } + + private def convertViaNativeRegex( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { val strExpr = exprToProtoInternal(expr.str, inputs, binding) val regexExpr = exprToProtoInternal(expr.regex, inputs, binding) val limitExpr = exprToProtoInternal(expr.limit, inputs, binding) @@ -423,6 +784,50 @@ object CometStringSplit extends CometExpressionSerde[StringSplit] { limitExpr) optExprWithInfo(optExpr, expr, expr.str, expr.regex, expr.limit) } + + private def convertViaJvmUdf( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regex match { + case Literal(pattern, DataTypes.StringType) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val strProto = exprToProtoInternal(expr.str, inputs, binding) + val regexProto = exprToProtoInternal(expr.regex, inputs, binding) + val limitProto = exprToProtoInternal(expr.limit, inputs, binding) + if (strProto.isEmpty || regexProto.isEmpty || limitProto.isEmpty) { + return None + } + val returnType = + serializeDataType(ArrayType(StringType, containsNull = false)).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.StringSplitUDF") + .addArgs(strProto.get) + .addArgs(regexProto.get) + .addArgs(limitProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regex patterns are supported") + None + } + } } object CometGetJsonObject extends CometExpressionSerde[GetJsonObject] { diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql new file mode 100644 index 0000000000..d1eab21409 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -0,0 +1,56 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. 
The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract via JVM regex engine (default engine) + +statement +CREATE TABLE test_regexp_extract(s string) USING parquet + +statement +INSERT INTO test_regexp_extract VALUES ('abc123def'), ('no match'), (NULL), ('xyz789'), ('hello world'), ('aa') + +-- group 0: entire match +query +SELECT regexp_extract(s, '\d+', 0) FROM test_regexp_extract + +-- group 1: first capturing group +query +SELECT regexp_extract(s, '([a-z]+)(\d+)', 1) FROM test_regexp_extract + +-- group 2: second capturing group +query +SELECT regexp_extract(s, '([a-z]+)(\d+)', 2) FROM test_regexp_extract + +-- no match returns empty string +query +SELECT regexp_extract(s, 'NOMATCH', 0) FROM test_regexp_extract + +-- backreference pattern (Java-only) +query +SELECT regexp_extract(s, '(\w)\1', 0) FROM test_regexp_extract + +-- lookahead (Java-only) +query +SELECT regexp_extract(s, 'abc(?=\d)', 0) FROM test_regexp_extract + +-- embedded flags (Java-only) +query +SELECT regexp_extract(s, '(?i)HELLO', 0) FROM test_regexp_extract + +-- literal arguments +query +SELECT regexp_extract('abc123', '(\d+)', 1), regexp_extract('no digits', '(\d+)', 1), regexp_extract(NULL, '(\d+)', 1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql new file mode 100644 index 0000000000..69b84875a4 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql @@ -0,0 +1,52 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
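The expectations the regexp_extract queries above encode reduce to a small reference function against plain java.util.regex, the engine the JVM path delegates to. A hedged sketch (`regexpExtract` is a hypothetical helper, not PR code):

{{{
import java.util.regex.Pattern

// NULL subject -> NULL; no match -> ""; group 0 -> the entire match;
// a group that did not participate in the match -> "".
def regexpExtract(s: String, regex: String, group: Int): String = {
  if (s == null) return null
  val m = Pattern.compile(regex).matcher(s)
  if (m.find()) Option(m.group(group)).getOrElse("") else ""
}

regexpExtract("abc123def", "([a-z]+)(\\d+)", 2) // "123"
regexpExtract("no match", "\\d+", 0)            // ""
}}}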
+ +-- Test regexp_extract_all via JVM regex engine (default engine) + +statement +CREATE TABLE test_regexp_extract_all(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_all VALUES ('abc123def456'), ('no match'), (NULL), ('100-200-300'), ('hello world') + +-- group 0: all entire matches +query +SELECT regexp_extract_all(s, '\d+', 0) FROM test_regexp_extract_all + +-- group 1: first capturing group from each match +query +SELECT regexp_extract_all(s, '([a-z]+)(\d+)', 1) FROM test_regexp_extract_all + +-- group 2: second capturing group from each match +query +SELECT regexp_extract_all(s, '([a-z]+)(\d+)', 2) FROM test_regexp_extract_all + +-- no match returns empty array +query +SELECT regexp_extract_all(s, 'NOMATCH', 0) FROM test_regexp_extract_all + +-- backreference pattern (Java-only) +query +SELECT regexp_extract_all(s, '(\d)\1', 0) FROM test_regexp_extract_all + +-- embedded flags (Java-only) +query +SELECT regexp_extract_all(s, '(?i)[A-Z]+', 0) FROM test_regexp_extract_all + +-- literal arguments +query +SELECT regexp_extract_all('abc123def456', '(\d+)', 1), regexp_extract_all('no digits', '(\d+)', 1), regexp_extract_all(NULL, '(\d+)', 1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql new file mode 100644 index 0000000000..c394b8bb4d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql @@ -0,0 +1,48 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
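Likewise for the regexp_extract_all queries above: every match contributes the requested group, with no matches yielding an empty array rather than NULL. A hedged sketch (`regexpExtractAll` is a hypothetical helper, not PR code):

{{{
import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer

// Collects the requested group from every match; no matches -> empty array,
// NULL subject -> NULL, which is what the queries above assert.
def regexpExtractAll(s: String, regex: String, group: Int): Seq[String] = {
  if (s == null) return null
  val m = Pattern.compile(regex).matcher(s)
  val out = ArrayBuffer.empty[String]
  while (m.find()) out += Option(m.group(group)).getOrElse("")
  out.toSeq
}

regexpExtractAll("abc123def456", "(\\d+)", 1) // Seq("123", "456")
}}}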
+ +-- Test regexp_instr via JVM regex engine (default engine) + +statement +CREATE TABLE test_regexp_instr(s string) USING parquet + +statement +INSERT INTO test_regexp_instr VALUES ('abc123def'), ('no match'), (NULL), ('123xyz'), ('hello world'), ('aa') + +-- basic: position of first digit sequence +query +SELECT regexp_instr(s, '\d+', 0) FROM test_regexp_instr + +-- group 1 (still returns position of entire match per Spark semantics) +query +SELECT regexp_instr(s, '([a-z]+)(\d+)', 1) FROM test_regexp_instr + +-- no match returns 0 +query +SELECT regexp_instr(s, 'NOMATCH', 0) FROM test_regexp_instr + +-- backreference pattern (Java-only) +query +SELECT regexp_instr(s, '(\w)\1', 0) FROM test_regexp_instr + +-- embedded flags (Java-only) +query +SELECT regexp_instr(s, '(?i)HELLO', 0) FROM test_regexp_instr + +-- literal arguments +query +SELECT regexp_instr('abc123', '\d+', 0), regexp_instr('no digits', '\d+', 0), regexp_instr(NULL, '\d+', 0) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql new file mode 100644 index 0000000000..ee8331314f --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql @@ -0,0 +1,50 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
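The regexp_instr tests above pin down two edges: no match returns 0, and, per the comment in that file, the position of the entire match is returned even when a group index is supplied. A hedged sketch under those assumptions (`regexpInstr` is a hypothetical helper, with the group argument omitted accordingly):

{{{
import java.util.regex.Pattern

// 1-based position of the first match; 0 when nothing matches; NULL in -> NULL out.
def regexpInstr(s: String, regex: String): Integer = {
  if (s == null) return null
  val m = Pattern.compile(regex).matcher(s)
  if (m.find()) m.start() + 1 else 0
}

regexpInstr("abc123", "\\d+") // 4
}}}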
+ +-- Test regexp_replace via JVM regex engine (default engine) + +statement +CREATE TABLE test_regexp_replace_java(s string) USING parquet + +statement +INSERT INTO test_regexp_replace_java VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890'), ('aabbcc') + +query +SELECT regexp_replace(s, '\d+', 'X') FROM test_regexp_replace_java + +query +SELECT regexp_replace(s, '\d+', 'X', 1) FROM test_regexp_replace_java + +-- backreference in replacement +query +SELECT regexp_replace(s, '(\d+)-(\d+)', '$2-$1') FROM test_regexp_replace_java + +-- backreference in pattern (Java-only) +query +SELECT regexp_replace(s, '(\w)\1', 'Z') FROM test_regexp_replace_java + +-- lookahead (Java-only) +query +SELECT regexp_replace(s, '\d+(?=-)', 'X') FROM test_regexp_replace_java + +-- embedded flags (Java-only) +query +SELECT regexp_replace(s, '(?i)ABC', 'X') FROM test_regexp_replace_java + +-- literal arguments +query +SELECT regexp_replace('100-200', '(\d+)', 'X'), regexp_replace('abc', '(\d+)', 'X'), regexp_replace(NULL, '(\d+)', 'X') diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql similarity index 82% rename from spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql rename to spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql index 967674a894..c4b030356b 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql @@ -14,6 +14,8 @@ -- KIND, either express or implied. See the License for the -- specific language governing permissions and limitations -- under the License. +-- Test regexp_replace with Rust regexp engine (patterns expected to fallback) +-- Config: spark.comet.exec.regexp.engine=rust statement CREATE TABLE test_regexp_replace(s string) USING parquet @@ -21,8 +23,8 @@ CREATE TABLE test_regexp_replace(s string) USING parquet statement INSERT INTO test_regexp_replace VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') -query expect_fallback(Regexp pattern) +query expect_fallback(is not fully compatible with Spark) SELECT regexp_replace(s, '(\\d+)', 'X') FROM test_regexp_replace -query expect_fallback(Regexp pattern) +query expect_fallback(is not fully compatible with Spark) SELECT regexp_replace(s, '(\\d+)', 'X', 1) FROM test_regexp_replace diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust_enabled.sql similarity index 91% rename from spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql rename to spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust_enabled.sql index 97b4917c33..ee275fbd61 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust_enabled.sql @@ -15,7 +15,8 @@ -- specific language governing permissions and limitations -- under the License. 
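For the regexp_replace tests above, java.util.regex already supplies the semantics the JVM path needs, including the `$1`/`$2` group references in the replacement that the `'$2-$1'` swap test relies on. A hedged sketch (`regexpReplace` is a hypothetical helper, not PR code):

{{{
import java.util.regex.Pattern

// Matcher.replaceAll implements the Java-side semantics directly, including
// '$n' group references in the replacement and empty-pattern insertion.
def regexpReplace(s: String, regex: String, rep: String): String =
  if (s == null) null
  else Pattern.compile(regex).matcher(s).replaceAll(rep)

regexpReplace("100-200", "(\\d+)-(\\d+)", "$2-$1") // "200-100"
regexpReplace("abc", "", "-")                      // "-a-b-c-"
}}}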
--- Test regexp_replace() with regexp allowIncompatible enabled (happy path) +-- Test regexp_replace() with Rust regexp engine and allowIncompatible enabled +-- Config: spark.comet.exec.regexp.engine=rust -- Config: spark.comet.expression.regexp.allowIncompatible=true statement diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql new file mode 100644 index 0000000000..5f4252b02f --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql @@ -0,0 +1,49 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test RLIKE via JVM regex engine (default engine) + +statement +CREATE TABLE test_rlike_java(s string) USING parquet + +statement +INSERT INTO test_rlike_java VALUES ('hello'), ('12345'), (''), (NULL), ('Hello World'), ('abc123'), ('aa'), ('ab') + +query +SELECT s RLIKE '^\d+$' FROM test_rlike_java + +query +SELECT s RLIKE '^[a-z]+$' FROM test_rlike_java + +query +SELECT s RLIKE '' FROM test_rlike_java + +-- backreference (Java-only) +query +SELECT s RLIKE '^(\w)\1$' FROM test_rlike_java + +-- lookahead (Java-only) +query +SELECT s RLIKE 'abc(?=\d)' FROM test_rlike_java + +-- embedded flags (Java-only) +query +SELECT s RLIKE '(?i)hello' FROM test_rlike_java + +-- literal arguments +query +SELECT 'hello' RLIKE '^[a-z]+$', '12345' RLIKE '^\d+$', '' RLIKE '', NULL RLIKE 'a' diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql similarity index 90% rename from spark/src/test/resources/sql-tests/expressions/string/rlike.sql rename to spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql index 97350918ba..3daf23f53c 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql @@ -15,6 +15,9 @@ -- specific language governing permissions and limitations -- under the License. 
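The rlike_java.sql tests above exercise RLIKE's substring semantics: a match anywhere in the subject is enough, which is why the empty pattern matches every non-null row. A hedged sketch (`rlike` is a hypothetical helper, not PR code):

{{{
import java.util.regex.Pattern

// RLIKE is a substring match (Matcher.find), not a full-string match;
// '' therefore matches every non-null subject, as the tests above expect.
def rlike(s: String, regex: String): java.lang.Boolean =
  if (s == null || regex == null) null
  else Pattern.compile(regex).matcher(s).find()

rlike("abc123", "\\d+") // true
rlike(null, "\\d+")     // null
}}}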
+-- Test RLIKE with Rust regexp engine (patterns expected to fallback) +-- Config: spark.comet.exec.regexp.engine=rust + statement CREATE TABLE test_rlike(s string) USING parquet diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust_enabled.sql similarity index 92% rename from spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql rename to spark/src/test/resources/sql-tests/expressions/string/rlike_rust_enabled.sql index 5b2bd05fb3..f4917b6228 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust_enabled.sql @@ -15,7 +15,8 @@ -- specific language governing permissions and limitations -- under the License. --- Test RLIKE with regexp allowIncompatible enabled (happy path) +-- Test RLIKE with Rust regexp engine and allowIncompatible enabled +-- Config: spark.comet.exec.regexp.engine=rust -- Config: spark.comet.expression.regexp.allowIncompatible=true statement diff --git a/spark/src/test/resources/sql-tests/expressions/string/split_java.sql b/spark/src/test/resources/sql-tests/expressions/string/split_java.sql new file mode 100644 index 0000000000..6420ca9cee --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/split_java.sql @@ -0,0 +1,52 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test split via JVM regex engine (default engine) + +statement +CREATE TABLE test_split_java(s string) USING parquet + +statement +INSERT INTO test_split_java VALUES ('one,two,three'), ('hello'), (''), (NULL), ('a::b::c'), ('aXbXc') + +-- basic split on comma +query +SELECT split(s, ',', -1) FROM test_split_java + +-- split with limit +query +SELECT split(s, ',', 2) FROM test_split_java + +-- split on regex pattern +query +SELECT split(s, '[,:]', -1) FROM test_split_java + +-- split on multi-char separator +query +SELECT split(s, '::', -1) FROM test_split_java + +-- lookahead in pattern (Java-only) +query +SELECT split(s, '(?=X)', -1) FROM test_split_java + +-- embedded flags (Java-only) +query +SELECT split(s, '(?i)x', -1) FROM test_split_java + +-- literal arguments +query +SELECT split('a,b,c', ',', -1), split('hello', ',', -1), split(NULL, ',', -1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/split_rust.sql b/spark/src/test/resources/sql-tests/expressions/string/split_rust.sql new file mode 100644 index 0000000000..fc1cf3d815 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/split_rust.sql @@ -0,0 +1,31 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test split with Rust regexp engine (patterns expected to fallback) +-- Config: spark.comet.exec.regexp.engine=rust + +statement +CREATE TABLE test_split_rust(s string) USING parquet + +statement +INSERT INTO test_split_rust VALUES ('one,two,three'), ('hello'), (''), (NULL), ('a::b::c') + +query expect_fallback(is not fully compatible with Spark) +SELECT split(s, ',', -1) FROM test_split_rust + +query expect_fallback(is not fully compatible with Spark) +SELECT split(s, '::', -1) FROM test_split_rust diff --git a/spark/src/test/resources/sql-tests/expressions/string/split_rust_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/split_rust_enabled.sql new file mode 100644 index 0000000000..048b44452b --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/split_rust_enabled.sql @@ -0,0 +1,39 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
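The split tests above lean on java.util.regex's limit contract, which lines up with Spark's for the cases exercised here (edge cases around zero-width matches are hedged by the dedicated lookahead tests). A sketch with a hypothetical helper, not PR code:

{{{
import java.util.regex.Pattern

// Pattern.split's limit semantics: limit > 0 caps the number of pieces,
// a negative limit keeps trailing empty strings.
def splitStr(s: String, regex: String, limit: Int): Array[String] =
  if (s == null) null
  else Pattern.compile(regex).split(s, limit)

splitStr("a,,b,,c", ",", -1).toSeq  // Seq("a", "", "b", "", "c")
splitStr("a,b,c,d,e", ",", 3).toSeq // Seq("a", "b", "c,d,e")
}}}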
+ +-- Test split with Rust regexp engine and allowIncompatible enabled +-- Config: spark.comet.exec.regexp.engine=rust +-- Config: spark.comet.expression.StringSplit.allowIncompatible=true + +statement +CREATE TABLE test_split_rust_enabled(s string) USING parquet + +statement +INSERT INTO test_split_rust_enabled VALUES ('one,two,three'), ('hello'), (''), (NULL), ('a::b::c') + +query +SELECT split(s, ',', -1) FROM test_split_rust_enabled + +query +SELECT split(s, ',', 2) FROM test_split_rust_enabled + +query +SELECT split(s, '::', -1) FROM test_split_rust_enabled + +-- literal arguments +query +SELECT split('a,b,c', ',', -1), split('hello', ',', -1), split(NULL, ',', -1) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 48b8905035..63936a94b7 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -243,7 +243,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp val df = spark.read .parquet(path.toString) .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("idx", udf((_: Int) => 1).apply(col("_4"))) + .withColumn("idx", org.apache.spark.sql.functions.udf((_: Int) => 1).apply(col("_4"))) .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), diff --git a/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala b/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala new file mode 100644 index 0000000000..e100c77913 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet.{CometFilterExec, CometProjectExec} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +class CometRegExpJvmSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_JAVA) + + // Patterns that the Rust regex crate cannot handle. Using one of these proves + // the JVM path was taken: if the pattern reached native, native would have + // rejected it and the operator would not be Comet. 
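A quick way to convince yourself the suite's patterns really are Java-only: each compiles under java.util.regex, while the Rust regex crate rejects backreferences and lookaround outright. A standalone snippet (not part of the suite):

{{{
import java.util.regex.Pattern

// Each of these compiles in Java regex; Rust's regex crate has no
// backreferences or lookaround, so such a pattern can only run via the bridge.
Seq("^(\\w)\\1$", "foo(?=bar)", "(?<=foo)bar", "(?<num>\\d)")
  .foreach(p => Pattern.compile(p)) // a PatternSyntaxException would fail fast here
}}}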
+ private val backreference = "^(\\\\w)\\\\1$" + private val lookahead = "foo(?=bar)" + private val lookbehind = "(?<=foo)bar" + private val embeddedFlags = "(?i)foo" + private val namedGroup = "(?<num>\\\\d)" + + private def withSubjects(values: String*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val rows = values + .map(v => if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + // ========== rlike tests ========== + + test("rlike: projection produces Java regex semantics with null handling") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + val df = sql("SELECT s, s rlike '\\\\d+' AS m FROM t") + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: predicate filters rows using Java regex semantics") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + val df = sql("SELECT s FROM t WHERE s rlike '\\\\d+'") + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: backreference in projection (Java-only construct)") { + withSubjects("aa", "ab", "xyzzy", null) { + val df = sql(s"SELECT s, s rlike '$backreference' FROM t") + checkSparkAnswerAndOperator(df) + val plan = df.queryExecution.executedPlan + assert( + collect(plan) { case p: CometProjectExec => p }.nonEmpty, + s"Expected CometProjectExec in:\n$plan") + } + } + + test("rlike: backreference in predicate (Java-only construct)") { + withSubjects("aa", "ab", "xyzzy", null) { + val df = sql(s"SELECT s FROM t WHERE s rlike '$backreference'") + checkSparkAnswerAndOperator(df) + val plan = df.queryExecution.executedPlan + assert( + collect(plan) { case f: CometFilterExec => f }.nonEmpty, + s"Expected CometFilterExec in:\n$plan") + } + } + + test("rlike: lookahead pattern (Java-only construct)") { + withSubjects("foobar", "foobaz", "barfoo", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$lookahead' FROM t")) + checkSparkAnswerAndOperator(sql(s"SELECT s FROM t WHERE s rlike '$lookahead'")) + } + } + + test("rlike: lookbehind pattern (Java-only construct)") { + withSubjects("foobar", "barbar", "foofoo", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$lookbehind' FROM t")) + } + } + + test("rlike: embedded case-insensitive flag (Java-only construct)") { + withSubjects("FOO", "foo", "fOO", "bar") { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$embeddedFlags' FROM t")) + } + } + + test("rlike: named groups (Java-only construct)") { + withSubjects("a1", "ab", "9z", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$namedGroup' FROM t")) + } + } + + test("rlike: empty pattern matches every non-null row") { + withSubjects("abc", "", null) { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '' FROM t")) + } + } + + test("rlike: empty subject string is handled correctly") { + withSubjects("", "x", null) { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '^$' FROM t")) + } + } + + test("rlike: all-null subject column produces all-null result") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT s rlike '\\\\d+' FROM t")) + } + } + + test("rlike: null literal pattern falls back to Spark") { + withSubjects("a", "b", null) { + checkSparkAnswer(sql("SELECT s rlike CAST(NULL AS STRING) FROM t")) + } + } + + test("rlike: invalid pattern falls back to Spark") { + withSubjects("a") { + val ex = intercept[Throwable](sql("SELECT s rlike '[' FROM t").collect()) + assert( + ex.getMessage.toLowerCase.contains("regex") || +
ex.getMessage.contains("PatternSyntax") || + ex.getMessage.contains("Unclosed"), + s"Unexpected error: ${ex.getMessage}") + } + } + + test("rlike: combines with filter, projection, and aggregate") { + withTable("t") { + sql("CREATE TABLE t (s STRING, k INT) USING parquet") + sql("""INSERT INTO t VALUES + | ('aa', 1), ('ab', 1), ('aa', 2), ('xyzzy', 2), ('aa', 3), (NULL, 3)""".stripMargin) + val df = sql(s"""SELECT k, COUNT(*) AS c + |FROM t + |WHERE s rlike '$backreference' + |GROUP BY k + |ORDER BY k""".stripMargin) + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: many rows spanning multiple batches") { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val values = (0 until 5000) + .map(i => if (i % 7 == 0) "(NULL)" else s"('row_${i}_aa')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $values") + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$backreference' FROM t")) + checkSparkAnswerAndOperator(sql(s"SELECT s FROM t WHERE s rlike '$backreference'")) + } + } + + // ========== regexp_extract tests ========== + + test("regexp_extract: basic group extraction") { + withSubjects("abc123def", "no match", null, "xyz789") { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '([a-z]+)(\\\\d+)', 1) FROM t")) + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '([a-z]+)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_extract: group 0 returns entire match") { + withSubjects("hello world", "foo123bar", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_extract: no match returns empty string") { + withSubjects("abc", "def", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_extract: backreference pattern (Java-only)") { + withSubjects("aa", "ab", "bb", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '(\\\\w)\\\\1', 0) FROM t")) + } + } + + test("regexp_extract: lookahead pattern (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, 'foo(?=bar)', 0) FROM t")) + } + } + + test("regexp_extract: embedded flags (Java-only)") { + withSubjects("FOO123", "foo456", "bar789") { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '(?i)(foo)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_extract: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT regexp_extract(s, '(\\\\d+)', 1) FROM t")) + } + } + + // ========== regexp_extract_all tests ========== + + test("regexp_extract_all: basic extraction of all matches") { + withSubjects("abc123def456", "no match", null, "x1y2z3") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '(\\\\d+)', 1) FROM t")) + } + } + + test("regexp_extract_all: group 0 returns full matches") { + withSubjects("cat bat hat", "no vowels", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '[a-z]at', 0) FROM t")) + } + } + + test("regexp_extract_all: multiple groups") { + withSubjects("a1b2c3", "x9y8", null) { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, '([a-z])(\\\\d)', 1) FROM t")) + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, '([a-z])(\\\\d)', 2) FROM t")) + } + } + + test("regexp_extract_all: no matches returns empty array") { + withSubjects("abc", "def") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '\\\\d+', 0) FROM t")) + } + 
} + + test("regexp_extract_all: lookahead pattern (Java-only)") { + withSubjects("foobar foobaz fooqux") { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, 'foo(?=ba[rz])', 0) FROM t")) + } + } + + // ========== regexp_replace tests ========== + + test("regexp_replace: basic replacement") { + withSubjects("abc123def456", "no digits", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '\\\\d+', 'NUM') FROM t")) + } + } + + test("regexp_replace: backreference in pattern (Java-only)") { + withSubjects("aabbcc", "abcabc", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '(\\\\w)\\\\1', 'X') FROM t")) + } + } + + test("regexp_replace: backreference in replacement") { + withSubjects("hello world", "foo bar", null) { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_replace(s, '(\\\\w+) (\\\\w+)', '$2 $1') FROM t")) + } + } + + test("regexp_replace: lookahead pattern (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, 'foo(?=bar)', 'XXX') FROM t")) + } + } + + test("regexp_replace: empty pattern replaces between characters") { + withSubjects("abc", "", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '', '-') FROM t")) + } + } + + test("regexp_replace: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT regexp_replace(s, '\\\\d', 'X') FROM t")) + } + } + + // ========== regexp_instr tests ========== + + test("regexp_instr: basic position finding") { + withSubjects("abc123def", "no match", null, "456xyz") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_instr: specific group position") { + withSubjects("abc123def456", "xyz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '([a-z]+)(\\\\d+)', 1) FROM t")) + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '([a-z]+)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_instr: no match returns 0") { + withSubjects("abc", "def", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_instr: lookahead (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, 'foo(?=bar)', 0) FROM t")) + } + } + + // ========== split tests ========== + + test("split: basic regex split") { + withSubjects("a,b,c", "x,,y", null, "single") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',') FROM t")) + } + } + + test("split: regex pattern") { + withSubjects("abc123def456ghi", "no-digits", null) { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, '\\\\d+') FROM t")) + } + } + + test("split: with limit") { + withSubjects("a,b,c,d,e") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',', 3) FROM t")) + } + } + + test("split: limit -1 returns all") { + withSubjects("a,,b,,c") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',', -1) FROM t")) + } + } + + test("split: lookahead pattern (Java-only)") { + withSubjects("camelCaseString", "anotherOne", null) { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, '(?=[A-Z])') FROM t")) + } + } + + test("split: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT split(s, ',') FROM t")) + } + } + + // ========== multi-batch and combined tests ========== + + test("regexp_extract: many rows spanning multiple batches") { + withTable("t") { + sql("CREATE TABLE 
t (s STRING) USING parquet") + val values = (0 until 5000) + .map(i => if (i % 7 == 0) "(NULL)" else s"('item_${i}_value')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $values") + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, 'item_(\\\\d+)_value', 1) FROM t")) + } + } + + test("all regexp expressions combined in one query") { + withSubjects("abc123def456", "hello world", null, "aa") { + checkSparkAnswerAndOperator(sql(""" + |SELECT + | s, + | s rlike '\\d+' AS has_digits, + | regexp_extract(s, '(\\d+)', 1) AS first_num, + | regexp_replace(s, '\\d+', 'N') AS replaced, + | regexp_instr(s, '\\d+', 0) AS num_pos + |FROM t + |""".stripMargin)) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala new file mode 100644 index 0000000000..956729f08a --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +/** + * Configuration for a single rlike pattern under benchmark. + * + * @param name + * short label for the pattern + * @param pattern + * the regex literal supplied to rlike + */ +case class RegExpPattern(name: String, pattern: String) + +/** + * Benchmark `rlike` across all execution modes: + * - Spark + * - Comet (Scan only) + * - Comet (Scan + Exec, native Rust regex) + * - Comet (Scan + Exec, JVM-side java.util.regex) + * + * To run: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 \ + * make benchmark-org.apache.spark.sql.benchmark.CometRegExpBenchmark + * }}} + * + * Results land in `spark/benchmarks/CometRegExpBenchmark-**results.txt`. + */ +object CometRegExpBenchmark extends CometBenchmarkBase { + + // Patterns chosen to span common rlike shapes. Avoid Java-only constructs + // that the native (Rust) path cannot accept, since those would be skipped + // rather than benchmarked in the native case. 
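To reproduce a single benchmark mode outside the harness, the engine can be flipped per session. A hedged spark-shell sketch using only the config handles this PR introduces (an ambient `spark` session and a table `t` with string column `s` are assumed):

{{{
import org.apache.comet.CometConf

spark.conf.set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_JAVA)
spark.sql("SELECT s RLIKE '^(\\\\w)\\\\1$' FROM t").show() // JVM bridge handles the backreference

spark.conf.set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_RUST)
spark.sql("SELECT s RLIKE '^(\\\\w)\\\\1$' FROM t").show() // pattern unsupported natively: Spark fallback
}}}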
+ private val patterns = List( + RegExpPattern("character_class", "[0-9]+"), + RegExpPattern("anchored", "^[0-9]"), + RegExpPattern("alternation", "abc|def|ghi"), + RegExpPattern("multi_class", "[a-zA-Z][0-9]+"), + RegExpPattern("repetition", "(ab){2,}")) + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("rlike modes", 1024) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + patterns.foreach { p => + val query = s"select c1 rlike '${p.pattern}' from parquetV1Table" + runBenchmark(p.name) { + runRLikeModes(p.name, v, query) + } + } + } + } + } + } + + /** Runs all four modes for a single rlike query. */ + private def runRLikeModes(name: String, cardinality: Long, query: String): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase("Spark") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Scan)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark.sql(query).noop() + } + } + + val baseExec = Map( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + "spark.sql.optimizer.constantFolding.enabled" -> "false") + + benchmark.addCase("Comet (Exec, native Rust regex)") { _ => + val configs = baseExec ++ Map(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") + withSQLConf(configs.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Exec, JVM regex)") { _ => + val configs = + baseExec ++ Map(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_JAVA) + withSQLConf(configs.toSeq: _*) { + spark.sql(query).noop() + } + } + + benchmark.run() + } +}
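Taken together, the serde objects in strings.scala apply one shared eligibility gate before emitting a JvmScalarUdf. A hedged distillation (sketch only; `jvmUdfEligible` is not a function in this PR):

{{{
import java.util.regex.{Pattern, PatternSyntaxException}

// JVM engine selected + literal, non-null, Java-compilable pattern => emit the
// JvmScalarUdf proto; anything else stays on the native path or falls back.
def jvmUdfEligible(engineIsJava: Boolean, literalPattern: Option[String]): Boolean =
  engineIsJava && literalPattern.exists { p =>
    try { Pattern.compile(p); true }
    catch { case _: PatternSyntaxException => false }
  }
}}}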