From d0f43afb9f0f34cf7d8e2ceea38c5f92a23c83a1 Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 6 May 2026 16:13:17 -0700 Subject: [PATCH 1/9] [SPARK-50593][SQL] Generalize ReducibleFunction reducer API with ReducibleParameters container --- .../catalog/functions/ReducibleFunction.java | 43 +++++++++ .../functions/ReducibleParameters.java | 93 +++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index ef1a14e50cdad..bfbf186516b10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -78,7 +78,10 @@ public interface ReducibleFunction { * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not + * @deprecated Use {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} + * for generic parameterized transforms. */ + @Deprecated default Reducer reducer( int thisNumBuckets, ReducibleFunction otherBucketFunction, @@ -103,4 +106,44 @@ default Reducer reducer( default Reducer reducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } + + /** + * Generic reducer for any parameterized transform function. + *

+ * This extends SPJ support beyond bucket to transforms like truncate, which use + * non-integer parameters or multiple parameters. + *

+ * Example of reducing f_source = truncate(x, 5) on f_target = truncate(x, 3): + *

    + *
  • thisParams = ReducibleParameters([5])
  • + *
  • otherFunction = truncate
  • + *
  • otherParams = ReducibleParameters([3])
  • + *
  • reducer truncates to min(5, 3) = 3
  • + *
+ *

+ * Default implementation provides backward compatibility: if both parameter sets + * contain a single integer, delegates to {@link #reducer(int, ReducibleFunction, int)}. + * + * @param thisParams parameters for this function + * @param otherFunction the other reducible function + * @param otherParams parameters for the other function + * @return a reduction function if it is reducible + * @throws UnsupportedOperationException if not reducible + * @since 4.1.0 + */ + default Reducer reducer( + ReducibleParameters thisParams, + ReducibleFunction otherFunction, + ReducibleParameters otherParams) { + // Backward compatibility: single-int params → delegate to old bucket API + if (thisParams.count() == 1 && otherParams.count() == 1) { + try { + return reducer(thisParams.getInt(0), otherFunction, otherParams.getInt(0)); + } catch (ClassCastException | NumberFormatException ignored) { + // Not int parameters, fall through + } + } + throw new UnsupportedOperationException( + "reducer() with ReducibleParameters not implemented"); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java new file mode 100644 index 0000000000000..03716225fac2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog.functions; + +import java.util.Collections; +import java.util.List; + +import org.apache.spark.annotation.Evolving; + +/** + * Container for parameters of a {@link ReducibleFunction}. + *

+ * Provides type-safe access to function parameters for generic reducer comparisons, + * enabling SPJ support for any parameterized transform (not just bucket). + *

+ * Examples: + *

    + *
  • bucket(4, x) → ReducibleParameters([4])
  • + *
  • truncate(x, 3) → ReducibleParameters([3])
  • + *
  • bucket(16, x) → ReducibleParameters([16])
  • + *
+ * + * @since 4.1.0 + */ +@Evolving +public class ReducibleParameters { + private final List values; + + public ReducibleParameters(List values) { + this.values = Collections.unmodifiableList(values); + } + + /** Number of parameters. */ + public int count() { + return values.size(); + } + + /** Get raw parameter value at index. */ + public Object get(int index) { + return values.get(index); + } + + /** Get parameter as int. Throws ClassCastException if not numeric. */ + public int getInt(int index) { + return ((Number) values.get(index)).intValue(); + } + + /** Get parameter as long. Throws ClassCastException if not numeric. */ + public long getLong(int index) { + return ((Number) values.get(index)).longValue(); + } + + /** Get parameter as String. */ + public String getString(int index) { + return (String) values.get(index); + } + + /** Get parameter as double. Throws ClassCastException if not numeric. */ + public double getDouble(int index) { + return ((Number) values.get(index)).doubleValue(); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (!(other instanceof ReducibleParameters)) return false; + return values.equals(((ReducibleParameters) other).values); + } + + @Override + public int hashCode() { + return values.hashCode(); + } + + @Override + public String toString() { + return "ReducibleParameters(" + values + ")"; + } +} From c225cc1f1917ae86c488858138cb2d73f011ae46 Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 6 May 2026 18:55:00 -0700 Subject: [PATCH 2/9] [SPARK-50593][SQL] Support truncate transform for Storage Partitioned Joins by generalizing parameter handling --- .../catalog/functions/ReducibleFunction.java | 89 ++++++------ .../functions/ReducibleParameters.java | 102 +++++++++---- .../expressions/TransformExpression.scala | 132 +++++++++++++---- .../expressions/V2ExpressionUtils.scala | 13 +- .../plans/physical/partitioning.scala | 12 +- .../v2/DistributionAndOrderingUtils.scala | 6 +- 
.../KeyGroupedPartitioningSuite.scala | 135 +++++++++++++++++- .../functions/transformFunctions.scala | 83 ++++++++++- ...rojectedOrderingAndPartitioningSuite.scala | 8 +- .../exchange/EnsureRequirementsSuite.scala | 4 +- 10 files changed, 450 insertions(+), 134 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index bfbf186516b10..13b033a98bec8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -60,6 +60,53 @@ @Evolving public interface ReducibleFunction { + /** + * Generic reducer for parameterized functions (bucket, truncate, etc.). + * + * If this function is 'reducible' on another function, return the {@link Reducer}. + *

+ * This method supports functions with any number of parameters of any type. + *

+ * Examples: + *

    + *
  • bucket(4, x) and bucket(2, x): + *
    thisParams = [4], otherParams = [2] + *
    Extract with: thisParams.getInt(0), otherParams.getInt(0) + *
  • + *
  • truncate(x, 3) and truncate(x, 5): + *
    thisParams = [3], otherParams = [5] + *
    Extract with: thisParams.getInt(0), otherParams.getInt(0) + *
  • + *
  • hypothetical range_bucket(x, 0L, 100L, 4): + *
    thisParams = [0L, 100L, 4] + *
    Extract with: thisParams.getLong(0), thisParams.getLong(1), thisParams.getInt(2) + *
  • + *
+ * + * @param thisParams parameters for this function + * @param otherFunction the other parameterized function + * @param otherParams parameters for the other function + * @return a reduction function if reducible, null otherwise + * @since 4.0.0 + */ + default Reducer reducer( + ReducibleParameters thisParams, + ReducibleFunction otherFunction, + ReducibleParameters otherParams) { + // Default: try old Int-based API for backward compatibility + if (thisParams.count() == 1 && otherParams.count() == 1) { + try { + return reducer( + thisParams.getInt(0), + otherFunction, + otherParams.getInt(0)); + } catch (ClassCastException ignored) { + // Not Int parameters, fall through + } + } + throw new UnsupportedOperationException(); + } + /** * This method is for the bucket function. * @@ -78,8 +125,6 @@ public interface ReducibleFunction { * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not - * @deprecated Use {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} - * for generic parameterized transforms. */ @Deprecated default Reducer reducer( @@ -106,44 +151,4 @@ default Reducer reducer( default Reducer reducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } - - /** - * Generic reducer for any parameterized transform function. - *

- * This extends SPJ support beyond bucket to transforms like truncate, which use - * non-integer parameters or multiple parameters. - *

- * Example of reducing f_source = truncate(x, 5) on f_target = truncate(x, 3): - *

    - *
  • thisParams = ReducibleParameters([5])
  • - *
  • otherFunction = truncate
  • - *
  • otherParams = ReducibleParameters([3])
  • - *
  • reducer truncates to min(5, 3) = 3
  • - *
- *

- * Default implementation provides backward compatibility: if both parameter sets - * contain a single integer, delegates to {@link #reducer(int, ReducibleFunction, int)}. - * - * @param thisParams parameters for this function - * @param otherFunction the other reducible function - * @param otherParams parameters for the other function - * @return a reduction function if it is reducible - * @throws UnsupportedOperationException if not reducible - * @since 4.1.0 - */ - default Reducer reducer( - ReducibleParameters thisParams, - ReducibleFunction otherFunction, - ReducibleParameters otherParams) { - // Backward compatibility: single-int params → delegate to old bucket API - if (thisParams.count() == 1 && otherParams.count() == 1) { - try { - return reducer(thisParams.getInt(0), otherFunction, otherParams.getInt(0)); - } catch (ClassCastException | NumberFormatException ignored) { - // Not int parameters, fall through - } - } - throw new UnsupportedOperationException( - "reducer() with ReducibleParameters not implemented"); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java index 03716225fac2d..4e0e3232a1c12 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java @@ -16,69 +16,115 @@ */ package org.apache.spark.sql.connector.catalog.functions; -import java.util.Collections; -import java.util.List; - import org.apache.spark.annotation.Evolving; +import java.util.Arrays; +import java.util.List; /** - * Container for parameters of a {@link ReducibleFunction}. - *

- * Provides type-safe access to function parameters for generic reducer comparisons, - * enabling SPJ support for any parameterized transform (not just bucket). - *

+ * Container for reducible function literal parameters. + * Provides type-safe access to parameters of various types. + * * Examples: *

    - *
  • bucket(4, x) → ReducibleParameters([4])
  • - *
  • truncate(x, 3) → ReducibleParameters([3])
  • - *
  • bucket(16, x) → ReducibleParameters([16])
  • + *
  • bucket(4, col) → ReducibleParameters([4])
  • + *
  • truncate(col, 3) → ReducibleParameters([3])
  • + *
  • range_bucket(col, 0L, 100L, 10) → ReducibleParameters([0L, 100L, 10])
  • + *
  • custom_transform(col, "param") → ReducibleParameters(["param"])
  • *
* - * @since 4.1.0 + * @since 4.0.0 */ @Evolving public class ReducibleParameters { private final List values; public ReducibleParameters(List values) { - this.values = Collections.unmodifiableList(values); + this.values = values; } - /** Number of parameters. */ + public ReducibleParameters(Object... values) { + this.values = Arrays.asList(values); + } + + /** + * Get the number of parameters. + */ public int count() { return values.size(); } - /** Get raw parameter value at index. */ - public Object get(int index) { - return values.get(index); + /** + * Check if this container has parameters. + */ + public boolean isEmpty() { + return values.isEmpty(); } - /** Get parameter as int. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Integer. + * @throws ClassCastException if parameter is not an Integer + * @throws IndexOutOfBoundsException if index is invalid + */ public int getInt(int index) { - return ((Number) values.get(index)).intValue(); + return (Integer) values.get(index); } - /** Get parameter as long. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Long. + * @throws ClassCastException if parameter is not a Long + * @throws IndexOutOfBoundsException if index is invalid + */ public long getLong(int index) { - return ((Number) values.get(index)).longValue(); + return (Long) values.get(index); } - /** Get parameter as String. */ + /** + * Get parameter at index as String. + * @throws ClassCastException if parameter is not a String + * @throws IndexOutOfBoundsException if index is invalid + */ public String getString(int index) { return (String) values.get(index); } - /** Get parameter as double. Throws ClassCastException if not numeric. */ + /** + * Get parameter at index as Double. 
+ * @throws ClassCastException if parameter is not a Double + * @throws IndexOutOfBoundsException if index is invalid + */ public double getDouble(int index) { - return ((Number) values.get(index)).doubleValue(); + return (Double) values.get(index); + } + + /** + * Get parameter at index as Float. + * @throws ClassCastException if parameter is not a Float + * @throws IndexOutOfBoundsException if index is invalid + */ + public float getFloat(int index) { + return (Float) values.get(index); + } + + /** + * Get raw parameter value at index. + */ + public Object get(int index) { + return values.get(index); + } + + /** + * Get all parameter values as a list. + */ + public List getAll() { + return values; } @Override - public boolean equals(Object other) { - if (this == other) return true; - if (!(other instanceof ReducibleParameters)) return false; - return values.equals(((ReducibleParameters) other).values); + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ReducibleParameters that = (ReducibleParameters) o; + return values.equals(that.values); } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 9041ed15fc501..eaf79bb6475c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, 
Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.DataType @@ -28,35 +28,87 @@ import org.apache.spark.sql.types.DataType * * @param function the transform function itself. Spark will use it to decide whether two * partition transform expressions are compatible. - * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise. */ -case class TransformExpression( - function: BoundFunction, - children: Seq[Expression], - numBucketsOpt: Option[Int] = None) extends Expression { +case class TransformExpression(function: BoundFunction, children: Seq[Expression]) + extends Expression { override def nullable: Boolean = true /** - * Whether this [[TransformExpression]] has the same semantics as `other`. - * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or - * `year(c)`. + * Extract literal children (constant parameters) from this transform. These are constant + * arguments like width in truncate(col, width). Literals are compared when checking if two + * transforms are the same. + */ + private lazy val literalChildren: Seq[Literal] = + children.collect { case l: Literal => l } + + /** + * Whether this [[TransformExpression]] has the same semantics as `other`. For instance, + * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`. + * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but may not to `truncate(c, 4)`. * * This will be used, for instance, by Spark to determine whether storage-partitioned join can * be triggered, by comparing partition transforms from both sides of the join and checking * whether they are compatible. * - * @param other the transform expression to compare to - * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise. + * Two transforms are considered the same if: + * 1. 
They have the same function name + * 2. They have the same literal arguments (e.g., numBuckets for bucket, width for truncate) + * + * @param other + * the transform expression to compare to + * @return + * true if this and `other` has the same semantics w.r.t to transform, false otherwise. */ def isSameFunction(other: TransformExpression): Boolean = other match { - case TransformExpression(otherFunction, _, otherNumBucketsOpt) => - function.canonicalName() == otherFunction.canonicalName() && - numBucketsOpt == otherNumBucketsOpt + case TransformExpression(otherFunction, _) => + val sameFunctionName = function.canonicalName() == otherFunction.canonicalName() + + // Compare literal arguments to ensure transforms with different parameters + // (e.g., bucket(32, col) vs bucket(16, col), truncate(col, 2) vs truncate(col, 4)) + // are not considered the same + val otherLiterals = other.literalChildren + val sameLiterals = literalChildren.length == otherLiterals.length && + literalChildren.zip(otherLiterals).forall { case (l1, l2) => + l1.equals(l2) + } + + sameFunctionName && sameLiterals case _ => false } + /** + * Override canonicalized to ensure transforms with the same function and literals are + * considered semantically equal, regardless of which specific column references they use. + * + * This is crucial for Storage Partitioned Joins - we need bucket(4, tableA.id) and bucket(4, + * tableB.id) to be semantically equal so SPJ can be triggered. + */ + override lazy val canonicalized: Expression = { + // Canonicalize only the non-literal children (i.e., column references) + val canonicalizedReferenceChildren = children.map { + case l: Literal => l + case other => other.canonicalized + } + TransformExpression(function, canonicalizedReferenceChildren) + } + + /** + * Override collectLeaves to only return reference children (columns), not literal parameters. 
+ * + * For TransformExpression, literal children are metadata about the transform function (e.g., + * numBuckets=4 in bucket(4, col), width=2 in truncate(col, 2)). All consumers of + * collectLeaves() expect only column references, not these metadata literals. + * + */ + override def collectLeaves(): Seq[Expression] = { + children.flatMap { + case _: Literal => Seq.empty // Skip literal parameters (metadata) + case other => other.collectLeaves() // Include column references + } + } + /** * Whether this [[TransformExpression]]'s function is compatible with the `other` * [[TransformExpression]]'s function. @@ -73,8 +125,8 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) - val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + val thisReducer = reducer(f, this, o, other) + val otherReducer = reducer(o, other, f, this) thisReducer.isDefined || otherReducer.isDefined case _ => false } @@ -92,22 +144,47 @@ case class TransformExpression( */ def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { - case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) + case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + reducer(e1, this, e2, other) case _ => None } } - // Return a Reducer for a reducible function on another reducible function + /** + * Extract all literal parameters from a transform expression. + * Returns ReducibleParameters containing the literal values in order. 
+ * + * Examples: + * bucket(4, col) => ReducibleParameters([4]) + * truncate(col, 3) => ReducibleParameters([3]) + * days(col) => ReducibleParameters([]) (no literals) + */ + private def extractParameters(expr: TransformExpression): ReducibleParameters = { + import scala.jdk.CollectionConverters._ + val values = expr.literalChildren.map { + case Literal(value, _) => value.asInstanceOf[AnyRef] + } + new ReducibleParameters(values.asJava) + } + + /** + * Return a Reducer for a reducible function on another reducible function + * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions. + */ private def reducer( thisFunction: ReducibleFunction[_, _], - thisNumBucketsOpt: Option[Int], + thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], - otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { - val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { - case (Some(numBuckets), Some(otherNumBuckets)) => - thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) - case _ => thisFunction.reducer(otherFunction) + otherExpr: TransformExpression): Option[Reducer[_, _]] = { + val thisParams = extractParameters(thisExpr) + val otherParams = extractParameters(otherExpr) + + val res = if (!thisParams.isEmpty && !otherParams.isEmpty) { + // Parameterized functions (bucket, truncate, etc.) + thisFunction.reducer(thisParams, otherFunction, otherParams) + } else { + // Non-parameterized functions (days, hours, etc.) 
+ thisFunction.reducer(otherFunction) } Option(res) } @@ -118,10 +195,7 @@ case class TransformExpression( copy(children = newChildren) private lazy val resolvedFunction: Option[Expression] = this match { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, - Seq(Literal(numBuckets)) ++ arguments)) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index c561677ed5ad5..5967491cf8e07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME -import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => 
V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue} import org.apache.spark.sql.connector.read.{SampleMethod => V2SampleMethod} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -116,17 +116,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match { case IdentityTransform(ref) => Some(resolveRef[NamedExpression](ref, query)) - case BucketTransform(numBuckets, refs, sorted) - if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) => - val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query)) - // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to - // look up the V2 function. - val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)() - funCatalogOpt.flatMap { catalog => - loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound => - TransformExpression(bound, resolvedRefs, Some(numBuckets)) - } - } case NamedTransform(name, args) => val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt)) funCatalogOpt.flatMap { catalog => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index aeacdaec7a8de..d80e66d4cb56c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -641,7 +641,9 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - transform.children.size == 1 && isReference(transform.children.head) + // 
TransformExpression.collectLeaves() only returns column references, not literals. + // We need exactly one column reference per transform. + transform.collectLeaves().size == 1 } @tailrec @@ -1335,7 +1337,13 @@ case class KeyedShuffleSpec( val newExpressions = partitioning.expressions.zip(keyPositions).map { case (te: TransformExpression, positionSet) => - te.copy(children = te.children.map(_ => clustering(positionSet.head))) + // Preserve literal parameters (e.g., numBuckets, truncate width) + // while replacing only column references with the new clustering expression + val newChildren = te.children.map { + case l: Literal => l // Keep literals as-is + case _ => clustering(positionSet.head) // Replace column references + } + te.copy(children = newChildren) case (_, positionSet) => clustering(positionSet.head) } KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala index 02e19dd053f29..9f3c4daa7a6f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, TransformExpression, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort} import 
org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -96,9 +96,7 @@ object DistributionAndOrderingUtils { } private def resolveTransformExpression(expr: Expression): Expression = expr.transform { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 711f6dbdcdb11..bf12815acab5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.functions.{col, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with ExplainSuiteHelper { private val functions = Seq( @@ -126,7 +127,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } @@ -138,7 +139,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = 
physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) // Has exactly one partition. val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) @@ -194,13 +195,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } } - test("non-clustered distribution: V2 function with multiple args") { + test("clustered distribution: V2 function with multiple args") { val partitions: Array[Transform] = Array( Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2)) ) @@ -216,7 +217,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val distribution = physical.ClusteredDistribution( Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2))))) - checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + // With truncate transform support, KeyedPartitioning should now work + val partitionKeys = Seq("aa", "bb", "cc").map(v => + InternalRow(UTF8String.fromString(v))) + checkQueryPlan(df, distribution, + physical.KeyedPartitioning(distribution.clustering, partitionKeys)) } /** @@ -4183,4 +4188,124 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + // === SPARK-50593: Truncate SPJ support tests === + + test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") { + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)) + ) + val partitions2 = Array( + Expressions.bucket(4, "data") + ) + + createTable("trunc_cross1", columns, 
partitions1) + sql("INSERT INTO testcat.ns.trunc_cross1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp))") + + createTable("trunc_cross2", columns2, partitions2) + sql("INSERT INTO testcat.ns.trunc_cross2 VALUES " + + "(1, 5, 'aaa'), " + + "(5, 10, 'bbb')") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + selectWithMergeJoinHint("trunc_cross1", "trunc_cross2") + + "trunc_cross1.id, trunc_cross2.store_id " + + "FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 " + + "ON trunc_cross1.data = trunc_cross2.data " + + "ORDER BY trunc_cross1.id") + + // Different functions (truncate vs bucket) should NEVER enable SPJ + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, + "truncate vs bucket should not trigger SPJ, but no shuffles found") + } + } + + test("SPARK-50593: TransformExpression.collectLeaves filters out literals") { + // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col] + val col = attr("data") + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(4), col)) + val leaves = bucketExpr.collectLeaves() + assert(leaves.size == 1, s"Expected 1 leaf (column ref), got ${leaves.size}: $leaves") + assert(leaves.head.semanticEquals(col), + s"Expected leaf to be the column reference, got ${leaves.head}") + + // truncate(col, 3) has children = [col, Literal(3)] but collectLeaves should return [col] + val truncExpr = TransformExpression(TruncateFunction, Seq(col, Literal(3))) + val truncLeaves = truncExpr.collectLeaves() + assert(truncLeaves.size == 1, + s"Expected 1 leaf (column ref), got ${truncLeaves.size}: $truncLeaves") + 
assert(truncLeaves.head.semanticEquals(col), + s"Expected leaf to be the column reference, got ${truncLeaves.head}") + + // years(col) has children = [col] with no literals + val yearsExpr = TransformExpression(YearsFunction, Seq(col)) + val yearsLeaves = yearsExpr.collectLeaves() + assert(yearsLeaves.size == 1, + s"Expected 1 leaf for years(), got ${yearsLeaves.size}: $yearsLeaves") + } + + test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { + // This test verifies that the migration from reducer(int, func, int) + // to reducer(ReducibleParameters, func, ReducibleParameters) is backward compatible. + // BucketFunction now implements the new API but bucket SPJ should still work. + val table1 = "bucket_compat1" + val table2 = "bucket_compat2" + + val partitions1 = Array(Expressions.bucket(4, "id")) + val partitions2 = Array(Expressions.bucket(4, "store_id")) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + "(2, 'ccc', CAST('2020-01-01' AS timestamp)), " + + "(3, 'ddd', CAST('2019-01-01' AS timestamp))") + + createTable(table2, columns2, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 5, 'aaa'), " + + "(1, 10, 'bbb'), " + + "(2, 15, 'ccc'), " + + "(3, 20, 'ddd')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + val df = sql( + selectWithMergeJoinHint(table1, table2) + + s"$table1.id, $table2.store_id " + + s"FROM testcat.ns.$table1 JOIN testcat.ns.$table2 " + + s"ON $table1.id = $table2.store_id " + + s"ORDER BY $table1.id") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "Bucket SPJ should still work after ReducibleParameters migration") + checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3))) + } + } + + test("SPARK-50593: ReducibleParameters backward compat - old int 
API still works via default") { + // The new reducer(ReducibleParameters, func, ReducibleParameters) default implementation + // delegates to the old reducer(int, func, int) for single-int params. + // This verifies bucket(4) vs bucket(2) still produces a reducer via the fallback path. + val bucketExpr4 = TransformExpression(BucketFunction, Seq(Literal(4), attr("id"))) + val bucketExpr2 = TransformExpression(BucketFunction, Seq(Literal(2), attr("id"))) + + // isCompatible should return true (4 and 2 share GCD > 1) + assert(bucketExpr4.isCompatible(bucketExpr2), + "bucket(4) and bucket(2) should be compatible via reducer") + + // reducers() should return a Reducer + val reducer = bucketExpr4.reducers(bucketExpr2) + assert(reducer.isDefined, "Expected a reducer for bucket(4) on bucket(2)") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 35102c6893d3b..8301ad70b0248 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -213,11 +213,14 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer( - thisNumBuckets: Int, + thisParams: ReducibleParameters, otherFunc: ReducibleFunction[_, _], - otherNumBuckets: Int): Reducer[Int, Int] = { + otherParams: ReducibleParameters): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { + val thisNumBuckets = thisParams.getInt(0) + val otherNumBuckets = otherParams.getInt(0) + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) if (gcd > 1 && gcd != thisNumBuckets) { return BucketReducer(gcd) @@ -253,12 +256,35 @@ object StringSelfFunction extends ScalarFunction[UTF8String] { } object UnboundTruncateFunction extends UnboundFunction { - override def 
bind(inputType: StructType): BoundFunction = TruncateFunction + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 2) { + inputType.head.dataType match { + case StringType | BinaryType => TruncateFunction + case IntegerType | LongType => IntegerTruncateFunction + case _ => + throw new UnsupportedOperationException( + s"'truncate' does not support data type: ${inputType.head.dataType}") + } + } else { + throw new UnsupportedOperationException( + "'truncate' requires exactly 2 arguments: (column, width)") + } + } + override def description(): String = name() override def name(): String = "truncate" } -object TruncateFunction extends ScalarFunction[UTF8String] { +/** + * Truncate transform for String/Binary types. + * Follows Iceberg spec: truncate(str, L) = str[0:L] + * + * Implements ReducibleFunction: ANY two different widths are compatible. + * The reducer uses the smaller width. + */ +object TruncateFunction + extends ScalarFunction[UTF8String] + with ReducibleFunction[UTF8String, UTF8String] { override def inputTypes(): Array[DataType] = Array(StringType, IntegerType) override def resultType(): DataType = StringType override def name(): String = "truncate" @@ -266,7 +292,52 @@ object TruncateFunction extends ScalarFunction[UTF8String] { override def toString: String = name() override def produceResult(input: InternalRow): UTF8String = { val str = input.getUTF8String(0) - val length = input.getInt(1) - str.substring(0, length) + val width = input.getInt(1) + str.substring(0, width) + } + + override def reducer( + thisParams: ReducibleParameters, + otherFunc: ReducibleFunction[_, _], + otherParams: ReducibleParameters): Reducer[UTF8String, UTF8String] = { + + if (otherFunc == TruncateFunction) { + val thisWidth = thisParams.getInt(0) + val otherWidth = otherParams.getInt(0) + val smallerWidth = math.min(thisWidth, otherWidth) + + if (smallerWidth != thisWidth) { + return TruncateReducer(smallerWidth) + } + } + null + } +} + +case 
class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] { + override def reduce(value: UTF8String): UTF8String = { + value.substring(0, width) + } + override def resultType(): DataType = StringType + override def displayName(): String = s"truncate($width)" +} + +/** + * Truncate transform for Integer/Long types. + * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W) + * + * Does NOT implement ReducibleFunction because different integer truncate widths + * produce incompatible partition structures. + */ +object IntegerTruncateFunction extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "truncate" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + val value = input.getInt(0) + val width = input.getInt(1) + value - (((value % width) + width) % width) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index a70baece77844..d451d44059765 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, TransformExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, TransformExpression} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, KeyedPartitioning, 
Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, YearsFunction} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -492,7 +492,7 @@ class ProjectedOrderingAndPartitioningSuite // KP([bucket(32, id)], keys1d) through Project(id as pk) should produce // KP([bucket(32, pk)], keys1d): the alias is pushed into the bucket's column argument. val id = AttributeReference("id", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val keys1d = Seq(InternalRow(0), InternalRow(1), InternalRow(2)) val child = DummyLeafExecWithPartitioning( output = Seq(id), @@ -524,7 +524,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id)], keys1d, isNarrowed=true, isGrouped=false). val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) // Projected to position [0] (bucket): (0),(1),(0) -- bucket value 0 appears twice. val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) @@ -554,7 +554,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id), years(ts_alias)], keys2d) -- not narrowed. 
val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) val child = DummyLeafExecWithPartitioning( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 17d00ec055e07..84eee883aeeda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1191,11 +1191,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { } def bucket(numBuckets: Int, expr: Expression): TransformExpression = { - TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets), expr)) } def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = { - TransformExpression(BucketFunction, expr, Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets)) ++ expr) } test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") { From f0d8088f8371bbbb67794da436abbe1e031faea9 Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 13 May 2026 17:17:54 -0700 Subject: [PATCH 3/9] [SPARK-50593][SQL] Strengthen test coverage for truncate SPJ and ReducibleParameters backward compatibility --- .../KeyGroupedPartitioningSuite.scala | 114 +++++++++++++----- .../functions/transformFunctions.scala | 31 +++++ ...rojectedOrderingAndPartitioningSuite.scala | 8 +- 3 files changed, 118 insertions(+), 35 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index bf12815acab5a..7250b31560caf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4210,25 +4210,71 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with "(5, 10, 'bbb')") withSQLConf( - SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_ALLOW_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { val df = sql( - selectWithMergeJoinHint("trunc_cross1", "trunc_cross2") + - "trunc_cross1.id, trunc_cross2.store_id " + - "FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 " + - "ON trunc_cross1.data = trunc_cross2.data " + - "ORDER BY trunc_cross1.id") + s""" + |${selectWithMergeJoinHint("trunc_cross1", "trunc_cross2")} + |trunc_cross1.id, trunc_cross2.store_id + |FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 + |ON trunc_cross1.data = trunc_cross2.data + |ORDER BY trunc_cross1.id + |""".stripMargin) - // Different functions (truncate vs bucket) should NEVER enable SPJ + // Different functions (truncate vs bucket) are not mutually reducible, so a shuffle + // must still be planned. 
val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.nonEmpty, - "truncate vs bucket should not trigger SPJ, but no shuffles found") + "truncate vs bucket are not compatible - a shuffle should be present, " + + "but none was planned") + checkAnswer(df, Seq(Row(0, 1), Row(1, 5))) } } + test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") { + // Exercises the ReducibleParameters-based reducer path end-to-end: truncate widths 3 and 5 + // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle. + val table1 = "trunc_three" + val table2 = "trunc_five" + + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3))) + val partitions2 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(5))) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(1, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(2, 'orange', CAST('2020-01-01' AS timestamp))") + + createTable(table2, columns, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(10, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(20, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(30, 'orange', CAST('2020-01-01' AS timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id AS left_id, $table2.id AS right_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.data = $table2.data + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "truncate(3) vs truncate(5) should avoid shuffle via the width reducer, " + + "but a shuffle was planned") + checkAnswer(df, 
Seq(Row(0, 10), Row(1, 20), Row(2, 30))) + } + } test("SPARK-50593: TransformExpression.collectLeaves filters out literals") { // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col] val col = attr("data") @@ -4254,14 +4300,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { - // This test verifies that the migration from reducer(int, func, int) - // to reducer(ReducibleParameters, func, ReducibleParameters) is backward compatible. - // BucketFunction now implements the new API but bucket SPJ should still work. + // Exercises the new ReducibleParameters-based reducer path end-to-end: bucket(4) and + // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer + // (now implemented via ReducibleParameters) correctly returns a GCD-based Reducer. val table1 = "bucket_compat1" val table2 = "bucket_compat2" val partitions1 = Array(Expressions.bucket(4, "id")) - val partitions2 = Array(Expressions.bucket(4, "store_id")) + val partitions2 = Array(Expressions.bucket(2, "store_id")) createTable(table1, columns, partitions1) sql(s"INSERT INTO testcat.ns.$table1 VALUES " + @@ -4278,34 +4324,40 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with "(3, 20, 'ddd')") withSQLConf( - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { val df = sql( - selectWithMergeJoinHint(table1, table2) + - s"$table1.id, $table2.store_id " + - s"FROM testcat.ns.$table1 JOIN testcat.ns.$table2 " + - s"ON $table1.id = $table2.store_id " + - s"ORDER BY $table1.id") + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id, $table2.store_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.id = $table2.store_id + |ORDER BY $table1.id + 
|""".stripMargin) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, - "Bucket SPJ should still work after ReducibleParameters migration") + "bucket(4) vs bucket(2) should avoid shuffle via the GCD reducer, " + + "but a shuffle was planned") checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3))) } } test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { - // The new reducer(ReducibleParameters, func, ReducibleParameters) default implementation - // delegates to the old reducer(int, func, int) for single-int params. - // This verifies bucket(4) vs bucket(2) still produces a reducer via the fallback path. - val bucketExpr4 = TransformExpression(BucketFunction, Seq(Literal(4), attr("id"))) - val bucketExpr2 = TransformExpression(BucketFunction, Seq(Literal(2), attr("id"))) - - // isCompatible should return true (4 and 2 share GCD > 1) - assert(bucketExpr4.isCompatible(bucketExpr2), - "bucket(4) and bucket(2) should be compatible via reducer") + // Verifies the default reducer(ReducibleParameters, ...) implementation correctly + // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only + // overrides the old API. This mirrors how Iceberg 1.10.0 (and earlier) ship without + // knowledge of ReducibleParameters. + val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id"))) + val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id"))) - // reducers() should return a Reducer val reducer = bucketExpr4.reducers(bucketExpr2) - assert(reducer.isDefined, "Expected a reducer for bucket(4) on bucket(2)") + assert(reducer.isDefined, "Expected a reducer for legacy_bucket(4) on legacy_bucket(2)") + + // Verify the returned Reducer actually reduces bucket 4 -> bucket 2 (GCD = 2). + // bucket(4, x) produces values in [0, 4); reducing by GCD=2 gives v % 2. 
+ val r = reducer.get.asInstanceOf[Reducer[Integer, Integer]] + assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}") + assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 8301ad70b0248..fc034729330e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -238,6 +238,37 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] { override def displayName(): String = toString } +/** + * A bucket function that only overrides the deprecated `reducer(int, func, int)` method, + * not the new `reducer(ReducibleParameters, func, ReducibleParameters)` method. + * + * Used to verify that the default implementation of the new method correctly falls back + * to the deprecated int-based API, so legacy implementations continue to work. 
+ */ +object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "legacy_bucket" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + Math.floorMod(input.getLong(1), input.getInt(0)) + } + + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { + if (otherFunc == LegacyBucketFunction) { + val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt + if (gcd > 1 && gcd != thisNumBuckets) { + return BucketReducer(gcd) + } + } + null + } +} + object UnboundStringSelfFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = StringSelfFunction override def description(): String = name() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index d451d44059765..629d65bb20c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -507,7 +507,7 @@ class ProjectedOrderingAndPartitioningSuite case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket function and numBuckets must be preserved after alias substitution") - assert(te.children.head.asInstanceOf[Attribute].name === "pk", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "pk", "bucket's column argument must be rewritten to the aliased attribute") case other => fail(s"Expected TransformExpression, got $other") } @@ -539,7 +539,7 @@ class 
ProjectedOrderingAndPartitioningSuite kp.expressions.head match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket must be the surviving expression") - assert(te.children.head.asInstanceOf[Attribute].name === "id") + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id") case other => fail(s"Expected TransformExpression, got $other") } assert(kp.isNarrowed, "dropping years(ts) position must mark the KP as narrowed") @@ -569,14 +569,14 @@ class ProjectedOrderingAndPartitioningSuite kp.expressions(0) match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "id", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id", "bucket's argument must remain id (no alias for id in this projection)") case other => fail(s"Expected TransformExpression at pos 0, got $other") } kp.expressions(1) match { case te: TransformExpression => assert(te.isSameFunction(yearsExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "ts_alias", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "ts_alias", "years() argument must be rewritten to ts_alias") case other => fail(s"Expected TransformExpression at pos 1, got $other") } From aeab76382803ee06f7f6fbe197313f520c6ff54d Mon Sep 17 00:00:00 2001 From: akhadka Date: Thu, 21 May 2026 17:11:09 -0700 Subject: [PATCH 4/9] [SPARK-50593][SQL] Harden reducer fallback and tighten SPJ supportsExpressions --- .../catalog/functions/ReducibleFunction.java | 21 +++----- .../functions/ReducibleParameters.java | 10 +++- .../expressions/TransformExpression.scala | 52 +++++++++++++++---- .../plans/physical/partitioning.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 43 +++++++++++++++ 5 files changed, 105 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 13b033a98bec8..c496196f3ca7c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -87,23 +87,12 @@ public interface ReducibleFunction { * @param otherFunction the other parameterized function * @param otherParams parameters for the other function * @return a reduction function if reducible, null otherwise - * @since 4.0.0 + * @since 5.0.0 */ default Reducer reducer( ReducibleParameters thisParams, ReducibleFunction otherFunction, ReducibleParameters otherParams) { - // Default: try old Int-based API for backward compatibility - if (thisParams.count() == 1 && otherParams.count() == 1) { - try { - return reducer( - thisParams.getInt(0), - otherFunction, - otherParams.getInt(0)); - } catch (ClassCastException ignored) { - // Not Int parameters, fall through - } - } throw new UnsupportedOperationException(); } @@ -125,8 +114,12 @@ default Reducer reducer( * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not + * @deprecated as of 5.0.0. Please override + * {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} instead. + * The new overload supports transforms with any number of parameters of any type + * (e.g. truncate width, multi-arg range buckets), not just a single int. */ - @Deprecated + @Deprecated(since = "5.0.0") default Reducer reducer( int thisNumBuckets, ReducibleFunction otherBucketFunction, @@ -149,6 +142,6 @@ default Reducer reducer( * @return a reduction function if it is reducible, null if not. 
*/ default Reducer reducer(ReducibleFunction otherFunction) { - throw new UnsupportedOperationException(); + return reducer(ReducibleParameters.EMPTY, otherFunction, ReducibleParameters.EMPTY); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java index 4e0e3232a1c12..836cd5d890e5d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; + +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -32,12 +34,18 @@ *
  • custom_transform(col, "param") → ReducibleParameters(["param"])
  • * * - * @since 4.0.0 + * @since 5.0.0 */ @Evolving public class ReducibleParameters { + public static final ReducibleParameters EMPTY = new ReducibleParameters(); + private final List values; + private ReducibleParameters() { + this.values = new ArrayList<>(); + } + public ReducibleParameters(List values) { this.values = values; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index eaf79bb6475c3..39773c990013c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.FUNCTION_NAME import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, StringType} +import org.apache.spark.unsafe.types.UTF8String /** * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. @@ -30,7 +35,7 @@ import org.apache.spark.sql.types.DataType * partition transform expressions are compatible. 
*/ case class TransformExpression(function: BoundFunction, children: Seq[Expression]) - extends Expression { + extends Expression with Logging { override def nullable: Boolean = true @@ -162,6 +167,8 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression private def extractParameters(expr: TransformExpression): ReducibleParameters = { import scala.jdk.CollectionConverters._ val values = expr.literalChildren.map { + case Literal(value, _: StringType) => value.asInstanceOf[UTF8String].toString + case Literal(value, _: DecimalType) => value.asInstanceOf[Decimal].toJavaBigDecimal case Literal(value, _) => value.asInstanceOf[AnyRef] } new ReducibleParameters(values.asJava) @@ -178,15 +185,42 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression otherExpr: TransformExpression): Option[Reducer[_, _]] = { val thisParams = extractParameters(thisExpr) val otherParams = extractParameters(otherExpr) + val thisName = thisExpr.function.canonicalName() - val res = if (!thisParams.isEmpty && !otherParams.isEmpty) { - // Parameterized functions (bucket, truncate, etc.) - thisFunction.reducer(thisParams, otherFunction, otherParams) - } else { - // Non-parameterized functions (days, hours, etc.) - thisFunction.reducer(otherFunction) + def isSingleInt(p: ReducibleParameters): Boolean = { + p.count() == 1 && p.get(0).isInstanceOf[Int] + } + + // Both thrown exceptions and `null` returns collapse to None; any failure + // to compute a reducer falls back to a shuffle (no SPJ). + def tryReduce[R](call: => R): Try[Option[R]] = { + val attempt = Try(Option(call)) + attempt.failed.foreach { + case e: UnsupportedOperationException => + logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " + + log"UnsupportedOperationException; treating as not reducible. 
Override " + + log"reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters) " + + log"to enable SPJ.") + case _ => + } + + attempt } - Option(res) + + val res: Try[Option[Reducer[_, _]]] = + if (thisParams.isEmpty && otherParams.isEmpty) { + tryReduce(thisFunction.reducer(otherFunction)) + } else if (isSingleInt(thisParams) && isSingleInt(otherParams)) { + // Try deprecated int-API first for legacy connectors (e.g. Iceberg 1.10); + // the first attempt is silent because we have a fallback. Only the fallback warns. + Try(Option(thisFunction.reducer( + thisParams.getInt(0), otherFunction, otherParams.getInt(0)))) + .orElse(tryReduce(thisFunction.reducer(thisParams, otherFunction, otherParams))) + } else { + // Parameterized functions (bucket, truncate, etc.) + tryReduce(thisFunction.reducer(thisParams, otherFunction, otherParams)) + } + res.toOption.flatten } override def dataType: DataType = function.resultType() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d80e66d4cb56c..495e2f8a0c73a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -641,9 +641,10 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - // TransformExpression.collectLeaves() only returns column references, not literals. + // Should only consider column references, not literals. + val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal]) // We need exactly one column reference per transform. 
- transform.collectLeaves().size == 1 + nonLiteralChildren.size == 1 && isReference(nonLiteralChildren.head) } @tailrec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 7250b31560caf..b1d2d1d1d47cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4343,6 +4343,49 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } + test("SPARK-50593: bucket(4) vs bucket(3) - no common divisor, must shuffle") { + // GCD(4, 3) = 1 -- BucketFunction.reducer returns null. Spark must NOT enable SPJ. + // Regression guard: a buggy null-handling in TransformExpression.reducer (e.g., + // Try(...).toOption instead of Try(Option(...))) would treat null as Some(null), + // enable SPJ, and produce wrong join results for incompatible bucket layouts. 
+ val table1 = "bucket_gcd1_a" + val table2 = "bucket_gcd1_b" + + val partitions1 = Array(Expressions.bucket(4, "id")) + val partitions2 = Array(Expressions.bucket(3, "store_id")) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + "(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + createTable(table2, columns2, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 5, 'aaa'), " + + "(1, 10, 'bbb'), " + + "(2, 15, 'ccc')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id, $table2.store_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.id = $table2.store_id + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, + "bucket(4) vs bucket(3) have no common divisor (GCD=1), so the reducer " + + "returns null. SPJ must NOT be enabled; a shuffle is required.") + checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2))) + } + } + test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { // Verifies the default reducer(ReducibleParameters, ...) 
implementation correctly // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only From 721f5761944d114fecde071bc5b828f819c43eee Mon Sep 17 00:00:00 2001 From: akhadka Date: Wed, 3 Jun 2026 14:50:06 -0700 Subject: [PATCH 5/9] [SPARK-50593][SQL] Make SPJ compatibility and reducer checks nested-transform aware --- .../expressions/TransformExpression.scala | 123 +++++++----- .../plans/physical/partitioning.scala | 19 +- .../KeyGroupedPartitioningSuite.scala | 187 +++++++++++++++--- .../functions/transformFunctions.scala | 48 ++++- 4 files changed, 292 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 39773c990013c..c60e343e10076 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -21,12 +21,11 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.FUNCTION_NAME -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, StringType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types.DataType /** * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. 
@@ -47,6 +46,23 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression private lazy val literalChildren: Seq[Literal] = children.collect { case l: Literal => l } + /** + * Whether every argument of this transform is either a literal parameter or a bare column + * reference (an [[Attribute]] or a [[GetStructField]] chain). This is the condition under which + * the transform's column slot can be safely rewritten to a join key (see + * `KeyedShuffleSpec.createPartitioning`): that rewrite replaces each non-literal child wholesale + * with the clustering key, which is only correct when the child IS a plain column reference. + * + * It excludes nested transforms (`bucket(4, years(ts))`), transforms hidden under a wrapper + * (`bucket(4, cast(years(ts)))`), and value-changing slots (`bucket(4, cast(a))`, + * `bucket(4, a + 1)`). A single non-recursive pass suffices: any disqualifying node appears at + * the top of some child, and [[isColumnRef]] rejects it without needing to look inside. + */ + def hasOnlyReferenceArgs: Boolean = children.forall { + case _: Literal => true + case e => isColumnRef(e) + } + /** * Whether this [[TransformExpression]] has the same semantics as `other`. For instance, * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`. @@ -56,63 +72,69 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression * be triggered, by comparing partition transforms from both sides of the join and checking * whether they are compatible. * - * Two transforms are considered the same if: - * 1. They have the same function name - * 2. They have the same literal arguments (e.g., numBuckets for bucket, width for truncate) + * Two transforms are considered the same when they have the same function name, the same arity, + * and each pair of corresponding children matches: + * - literal arguments must be equal (e.g. 
numBuckets for bucket, width for truncate), so that + * `bucket(32, c)` is not the same as `bucket(16, c)`; + * - nested transform arguments must recursively be the same function, so that + * `bucket(4, years(c))` is not the same as `bucket(4, days(c))`; + * - everything else must be a plain column reference on both sides. Column identity is + * intentionally ignored (it is reconciled separately via positional matching), but a + * non-reference slot such as `c + 1` or `cast(c)`, or a literal/transform-vs-reference + * mismatch, is treated as not the same. * * @param other * the transform expression to compare to * @return * true if this and `other` has the same semantics w.r.t to transform, false otherwise. */ - def isSameFunction(other: TransformExpression): Boolean = other match { - case TransformExpression(otherFunction, _) => - val sameFunctionName = function.canonicalName() == otherFunction.canonicalName() - - // Compare literal arguments to ensure transforms with different parameters - // (e.g., bucket(32, col) vs bucket(16, col), truncate(col, 2) vs truncate(col, 4)) - // are not considered the same - val otherLiterals = other.literalChildren - val sameLiterals = literalChildren.length == otherLiterals.length && - literalChildren.zip(otherLiterals).forall { case (l1, l2) => - l1.equals(l2) - } - - sameFunctionName && sameLiterals - case _ => - false + def isSameFunction(other: TransformExpression): Boolean = + function.canonicalName() == other.function.canonicalName() && + childrenMatch(other)(_ == _) + + @scala.annotation.tailrec + private def isColumnRef(e: Expression): Boolean = e match { + case _: Attribute => true + case g: GetStructField => isColumnRef(g.child) + case _ => false } /** - * Override canonicalized to ensure transforms with the same function and literals are - * considered semantically equal, regardless of which specific column references they use. 
+ * Whether every non-literal child of this and `other` is structurally the same: nested transforms + * must recursively be the same function, and any other slot must be a column reference. Literal + * children may differ -- they are exactly the parameters a [[Reducer]] is allowed to reconcile. * - * This is crucial for Storage Partitioned Joins - we need bucket(4, tableA.id) and bucket(4, - * tableB.id) to be semantically equal so SPJ can be triggered. + * This guards the reducer path. A [[Reducer]] is derived from the outer literal parameters alone + * (e.g. bucket numBuckets, truncate width); the nested transform children are not visible to it. + * It is therefore only valid when those nested children are identical. Without this check, + * `bucket(4, years(ts))` and `bucket(2, days(ts))` would be reduced via `gcd(4, 2) = 2`, silently + * joining mismatched partitions even though `years(ts)` and `days(ts)` are different transforms. */ - override lazy val canonicalized: Expression = { - // Canonicalize only the non-literal children (i.e., column references) - val canonicalizedReferenceChildren = children.map { - case l: Literal => l - case other => other.canonicalized - } - TransformExpression(function, canonicalizedReferenceChildren) - } + private def nonLiteralChildrenSame(other: TransformExpression): Boolean = + childrenMatch(other)((_, _) => true) /** - * Override collectLeaves to only return reference children (columns), not literal parameters. - * - * For TransformExpression, literal children are metadata about the transform function (e.g., - * numBuckets=4 in bucket(4, col), width=2 in truncate(col, 2)). All consumers of - * collectLeaves() expect only column references, not these metadata literals. + * Pairwise-match this transform's children against `other`'s. Requires equal arity, recursively + * the same function for nested transform arguments, and a plain column reference (Attribute / + * GetStructField chain) on both sides for any other slot. 
The `literalsMatch` predicate decides + * how literal parameters are compared: + * - `_ == _` for exact equality ([[isSameFunction]]); + * - `(_, _) => true` to allow them to differ ([[nonLiteralChildrenSame]], the reducer check, + * where differing literal parameters are exactly what a [[Reducer]] reconciles). * + * Note nested transform arguments always require full sameness via [[isSameFunction]] regardless + * of `literalsMatch`: the reducer is blind to nested transforms, so they must be identical. */ - override def collectLeaves(): Seq[Expression] = { - children.flatMap { - case _: Literal => Seq.empty // Skip literal parameters (metadata) - case other => other.collectLeaves() // Include column references - } - } + private def childrenMatch(other: TransformExpression) + (literalsMatch: (Literal, Literal) => Boolean): Boolean = + children.length == other.children.length && + children.zip(other.children).forall { + case (l1: Literal, l2: Literal) => literalsMatch(l1, l2) + case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2) + // any other pair: both must be plain column references; column identity is ignored, but a + // non-reference slot (Add, Cast, ...) 
or a Literal/Transform-vs-ref mismatch is "not same" + case (c1, c2) => isColumnRef(c1) && isColumnRef(c2) + } /** * Whether this [[TransformExpression]]'s function is compatible with the `other` @@ -167,9 +189,7 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression private def extractParameters(expr: TransformExpression): ReducibleParameters = { import scala.jdk.CollectionConverters._ val values = expr.literalChildren.map { - case Literal(value, _: StringType) => value.asInstanceOf[UTF8String].toString - case Literal(value, _: DecimalType) => value.asInstanceOf[Decimal].toJavaBigDecimal - case Literal(value, _) => value.asInstanceOf[AnyRef] + case Literal(value, dt) => CatalystTypeConverters.convertToScala(value, dt) } new ReducibleParameters(values.asJava) } @@ -183,6 +203,13 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], otherExpr: TransformExpression): Option[Reducer[_, _]] = { + // The reducer is derived from the literal parameters only (extractParameters drops nested + // transform children), so it is valid only when every non-literal child is structurally + // identical. This protects both `isCompatible` and the public `reducers` entry point from + // reducing across unrelated nested transforms, e.g. 
bucket(4, years(ts)) vs bucket(2, days(ts)) + if (!thisExpr.nonLiteralChildrenSame(otherExpr)) { + return None + } val thisParams = extractParameters(thisExpr) val otherParams = extractParameters(otherExpr) val thisName = thisExpr.function.canonicalName() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 495e2f8a0c73a..0d6154fdd662f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -641,10 +641,10 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - // Should only consider column references, not literals. - val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal]) - // We need exactly one column reference per transform. - nonLiteralChildren.size == 1 && isReference(nonLiteralChildren.head) + // Use Expression.references (AttributeSet), which already filters out non-attribute + // leaves (e.g., Literal parameters introduced by parameterized transforms such as + // `bucket(numBuckets, col)` or `truncate(col, width)`). + transform.references.size == 1 } @tailrec @@ -1328,8 +1328,15 @@ case class KeyedShuffleSpec( override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && - partitioning.expressions.forall { e => - e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] + partitioning.expressions.forall { + case _: AttributeReference => true + // Only repartition-to-match a transform whose every argument is a literal or a bare column + // reference. 
Otherwise createPartitioning would rewrite a non-reference slot (nested + // transform, cast, arithmetic) wholesale to the join key and flatten its semantics away + // (e.g. bucket(4, years(ts)) -> bucket(4, key)), producing a non-co-locating partitioning. + // Fallback to a regular shuffle for those. + case t: TransformExpression => t.hasOnlyReferenceArgs + case _ => false } override def createPartitioning(clustering: Seq[Expression]): Partitioning = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b1d2d1d1d47cf..4ab3634029891 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4275,29 +4275,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with checkAnswer(df, Seq(Row(0, 10), Row(1, 20), Row(2, 30))) } } - test("SPARK-50593: TransformExpression.collectLeaves filters out literals") { - // bucket(4, col) has children = [Literal(4), col] but collectLeaves should return [col] - val col = attr("data") - val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(4), col)) - val leaves = bucketExpr.collectLeaves() - assert(leaves.size == 1, s"Expected 1 leaf (column ref), got ${leaves.size}: $leaves") - assert(leaves.head.semanticEquals(col), - s"Expected leaf to be the column reference, got ${leaves.head}") - - // truncate(col, 3) has children = [col, Literal(3)] but collectLeaves should return [col] - val truncExpr = TransformExpression(TruncateFunction, Seq(col, Literal(3))) - val truncLeaves = truncExpr.collectLeaves() - assert(truncLeaves.size == 1, - s"Expected 1 leaf (column ref), got ${truncLeaves.size}: $truncLeaves") - assert(truncLeaves.head.semanticEquals(col), - s"Expected leaf to be the column reference, got 
${truncLeaves.head}") - - // years(col) has children = [col] with no literals - val yearsExpr = TransformExpression(YearsFunction, Seq(col)) - val yearsLeaves = yearsExpr.collectLeaves() - assert(yearsLeaves.size == 1, - s"Expected 1 leaf for years(), got ${yearsLeaves.size}: $yearsLeaves") - } test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { // Exercises the new ReducibleParameters-based reducer path end-to-end: bucket(4) and @@ -4386,6 +4363,170 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } + test("SPARK-50593: nested transforms with differing inner are not compatible " + + "(no false-positive SPJ)") { + import org.apache.spark.sql.catalyst.expressions.Expression + import org.apache.spark.sql.types.StringType + // A reducer is derived from the outer literal parameters only and is blind to nested inner + // transforms (extractParameters keeps only literalChildren). The nonLiteralChildrenSame guard + // in TransformExpression.reducer must refuse to reduce when the inner transforms differ; + // otherwise SPJ would silently co-locate mismatched partitions and drop matching rows. + val ts = attr("ts") + val data = AttributeReference("data", StringType)() + def bucket(n: Int, e: Expression): TransformExpression = + TransformExpression(BucketFunction, Seq(Literal(n), e)) + def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) + def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e)) + def truncate(e: Expression, w: Int): TransformExpression = + TransformExpression(TruncateFunction, Seq(e, Literal(w))) + + // Outer bucket reducible (4 vs 2) but inner differs (years vs days) -> NOT compatible. 
+ val l1 = bucket(4, years(ts)) + val r1 = bucket(2, days(ts)) + assert(!l1.isSameFunction(r1)) + assert(!l1.isCompatible(r1), + "bucket(4, years(ts)) must not be compatible with bucket(2, days(ts))") + assert(l1.reducers(r1).isEmpty && r1.reducers(l1).isEmpty, + "no reducer may be produced across differing inner transforms") + + // No-arg reducer path: outer days/years (empty params) over differing inner buckets. + val l2 = days(bucket(4, ts)) + val r2 = years(bucket(2, ts)) + assert(!l2.isCompatible(r2), + "days(bucket(4, ts)) must not be compatible with years(bucket(2, ts))") + assert(l2.reducers(r2).isEmpty) + + // Truncate manifestation: outer width reducible (10 -> 5) but inner differs -> NOT compatible. + val l3 = truncate(bucket(4, data), 10) + val r3 = truncate(bucket(2, data), 5) + assert(!l3.isCompatible(r3), + "truncate(.., 10) over bucket(4) must not be compatible with truncate(.., 5) over bucket(2)") + assert(l3.reducers(r3).isEmpty) + + // Positive control 1: same inner, outer reducible -> still compatible with a reducer. + val same1 = bucket(4, years(ts)) + val same2 = bucket(2, years(ts)) + assert(same1.isCompatible(same2), + "same inner years(ts) with bucket(4) vs bucket(2) must remain compatible") + assert(same1.reducers(same2).isDefined) + + // Positive control 2: flat cross-function reducible days(ts) vs years(ts) -> still compatible. 
+ assert(days(ts).isCompatible(years(ts)), + "flat days(ts) vs years(ts) must remain compatible (no regression)") + } + + test("SPARK-50593: isSameFunction recurses into nested transforms, respects column-ref slots") { + import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GetStructField} + val a = attr("a") + val b = attr("b") + def bucket(n: Int, e: Expression): TransformExpression = + TransformExpression(BucketFunction, Seq(Literal(n), e)) + def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) + def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e)) + + // Nested identical -> same, with column identity ignored. Without this, the primary nested SPJ + // case (bucket(4, years(ts)) on both sides) would silently fall back to a shuffle. + assert(bucket(4, years(a)).isSameFunction(bucket(4, years(a)))) + assert(bucket(4, years(a)).isSameFunction(bucket(4, years(b))), "column identity is ignored") + // Nested different inner -> not same. + assert(!bucket(4, years(a)).isSameFunction(bucket(4, days(a)))) + // Different outer literal -> not same. + assert(!bucket(4, years(a)).isSameFunction(bucket(2, years(a)))) + // Flat sanity (no nesting). + assert(bucket(4, a).isSameFunction(bucket(4, b))) + assert(!bucket(4, a).isSameFunction(bucket(2, b))) + + // A non-reference column slot (a + 1) carries value-changing semantics, so it is conservatively + // treated as not-same -- even compared to itself. + val add = bucket(4, Add(a, Literal(1))) + assert(!add.isSameFunction(bucket(4, Add(b, Literal(1))))) + assert(!add.isSameFunction(add), "a non-reference slot is treated as not-same by design") + + // Struct-field column references are recognized (reflexivity preserved for genuine refs). 
+ val s = AttributeReference("s", StructType(Seq(StructField("f", IntegerType))))() + val sf = GetStructField(s, 0) + assert(bucket(4, sf).isSameFunction(bucket(4, sf))) + } + + test("SPARK-50593: supportsExpressions admits single-attribute nested, rejects multi-attribute") { + import org.apache.spark.sql.catalyst.expressions.Add + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + + // Single-attribute nested transform -> admitted (references = {a}). + val nested = TransformExpression(BucketFunction, + Seq(Literal(4), TransformExpression(YearsFunction, Seq(a)))) + assert(nested.references.size == 1) + assert(physical.KeyedPartitioning.supportsExpressions(Seq(nested))) + + // Multi-attribute transform -> rejected at the gate; the positional model (keyPositions) needs + // exactly one clustering column per partition expression. + val multi = TransformExpression(BucketFunction, Seq(Literal(4), Add(a, b))) + assert(multi.references.size == 2) + assert(!physical.KeyedPartitioning.supportsExpressions(Seq(multi))) + } + + test("SPARK-50593: canCreatePartitioning is false for nested transforms (avoids flattening)") { + import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Expression} + val ts = AttributeReference("ts", TimestampType)() + val n = AttributeReference("n", IntegerType)() + def bucket(n: Int, e: Expression): TransformExpression = + TransformExpression(BucketFunction, Seq(Literal(n), e)) + def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) + + // hasOnlyReferenceArgs: true only when every arg is a literal or a bare column reference. 
+ assert(bucket(4, ts).hasOnlyReferenceArgs) // bare column + assert(!bucket(4, years(ts)).hasOnlyReferenceArgs) // nested transform + assert(!bucket(4, Cast(years(ts), IntegerType)).hasOnlyReferenceArgs) // transform under a cast + assert(!bucket(4, Cast(n, LongType)).hasOnlyReferenceArgs) // cast slot + assert(!bucket(4, Add(n, Literal(1))).hasOnlyReferenceArgs) // arithmetic slot + + val dist = physical.ClusteredDistribution(Seq(ts)) + val nested = physical.KeyedPartitioning(Seq(bucket(4, years(ts))), Seq.empty) + val flat = physical.KeyedPartitioning(Seq(bucket(4, ts)), Seq.empty) + + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false") { + // A nested transform must NOT be repartitioned-to-match: createPartitioning would flatten + // bucket(4, years(ts)) -> bucket(4, key), producing a non-co-locating partitioning. It must + // opt out so EnsureRequirements falls back to a regular shuffle. + assert(!physical.KeyedShuffleSpec(nested, dist).canCreatePartitioning, + "nested transform must opt out of createPartitioning") + // Flat transform remains repartitionable (no regression for v2 bucketing-shuffle). + assert(physical.KeyedShuffleSpec(flat, dist).canCreatePartitioning) + } + } + + test("SPARK-50593: integer truncate is reducible via lcm (generalized reducer, non-bucket)") { + // A second reducible transform exercising the generalized ReducibleParameters API with reducer + // math distinct from bucket (GCD) and string truncate (prefix-min): integer truncate snaps to + // a coarser grid, so truncate(v, W1) and truncate(v, W2) reduce onto multiples of lcm(W1, W2). + import org.apache.spark.sql.catalyst.expressions.Expression + val id = attr("id") + def itrunc(e: Expression, w: Int): TransformExpression = + TransformExpression(IntegerTruncateFunction, Seq(e, Literal(w))) + + // Same width -> same function (no reduction needed). 
+ assert(itrunc(id, 4).isSameFunction(itrunc(id, 4))) + + // W2 is a multiple of W1: the finer side (W1=2) reduces onto the coarser grid (W2=4). + assert(itrunc(id, 2).isCompatible(itrunc(id, 4))) + val r = itrunc(id, 2).reducers(itrunc(id, 4)) + assert(r.isDefined, "truncate(2) must reduce onto truncate(4)") + val red = r.get.asInstanceOf[Reducer[Integer, Integer]] + // truncate(.,2) values snapped to multiples of 4: 6 -> 4, 2 -> 0, 8 -> 8 + assert(red.reduce(6) == 4 && red.reduce(2) == 0 && red.reduce(8) == 8) + // The coarser side (4) is already the common grid -> no reducer. + assert(itrunc(id, 4).reducers(itrunc(id, 2)).isEmpty) + + // Neither divides the other: both sides reduce to the lcm grid. + assert(itrunc(id, 6).isCompatible(itrunc(id, 4))) // lcm(6, 4) = 12 + assert(itrunc(id, 6).reducers(itrunc(id, 4)).isDefined) + assert(itrunc(id, 4).reducers(itrunc(id, 6)).isDefined) + assert(itrunc(id, 3).isCompatible(itrunc(id, 5))) // coprime -> lcm(3, 5) = 15 + } + test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { // Verifies the default reducer(ReducibleParameters, ...) 
implementation correctly // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index fc034729330e8..094db237a0c8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -290,8 +290,8 @@ object UnboundTruncateFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = { if (inputType.size == 2) { inputType.head.dataType match { - case StringType | BinaryType => TruncateFunction - case IntegerType | LongType => IntegerTruncateFunction + case StringType => TruncateFunction + case IntegerType => IntegerTruncateFunction case _ => throw new UnsupportedOperationException( s"'truncate' does not support data type: ${inputType.head.dataType}") @@ -307,7 +307,7 @@ object UnboundTruncateFunction extends UnboundFunction { } /** - * Truncate transform for String/Binary types. + * Truncate transform for String type. * Follows Iceberg spec: truncate(str, L) = str[0:L] * * Implements ReducibleFunction: ANY two different widths are compatible. @@ -354,13 +354,18 @@ case class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] { } /** - * Truncate transform for Integer/Long types. - * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W) + * Truncate transform for Integer type. + * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W), which snaps `value` + * down to a multiple of `W`. * - * Does NOT implement ReducibleFunction because different integer truncate widths - * produce incompatible partition structures. 
+ * Implements ReducibleFunction: truncate(v, W1) and truncate(v, W2) are always reducible onto a + * common coarser grid of multiples of lcm(W1, W2). The finer side (whose width does not already + * equal the lcm) reduces by snapping to that grid; when W2 is a multiple of W1 the lcm is simply + * the coarser width W2. */ -object IntegerTruncateFunction extends ScalarFunction[Int] { +object IntegerTruncateFunction + extends ScalarFunction[Int] + with ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) override def resultType(): DataType = IntegerType override def name(): String = "truncate" @@ -371,4 +376,31 @@ object IntegerTruncateFunction extends ScalarFunction[Int] { val width = input.getInt(1) value - (((value % width) + width) % width) } + + override def reducer( + thisParams: ReducibleParameters, + otherFunc: ReducibleFunction[_, _], + otherParams: ReducibleParameters): Reducer[Int, Int] = { + if (otherFunc == IntegerTruncateFunction) { + val thisWidth = thisParams.getInt(0) + val otherWidth = otherParams.getInt(0) + val common = lcm(thisWidth, otherWidth) + // Only the finer side reduces; if `common == thisWidth` this side is already the common grid. 
+ if (common != thisWidth) { + return IntTruncateReducer(common) + } + } + null + } + + private def lcm(a: Int, b: Int): Int = { + val g = BigInt(a).gcd(BigInt(b)) + (BigInt(a) / g * BigInt(b)).toInt + } +} + +case class IntTruncateReducer(width: Int) extends Reducer[Int, Int] { + override def reduce(value: Int): Int = value - (((value % width) + width) % width) + override def resultType(): DataType = IntegerType + override def displayName(): String = s"truncate($width)" } From 9a706959267fe6e17f5a46ef47f87fda2c42c553 Mon Sep 17 00:00:00 2001 From: akhadka Date: Thu, 4 Jun 2026 16:07:58 -0700 Subject: [PATCH 6/9] [SPARK-50593][SQL] Restore strict SPJ transform gate, keep recursive isSameFunction --- .../expressions/TransformExpression.scala | 88 ++++--------- .../plans/physical/partitioning.scala | 29 ++--- .../KeyGroupedPartitioningSuite.scala | 117 ++++-------------- 3 files changed, 51 insertions(+), 183 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index c60e343e10076..c0eab818f2a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.annotation.tailrec import scala.util.Try import org.apache.spark.internal.Logging @@ -46,23 +47,6 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression private lazy val literalChildren: Seq[Literal] = children.collect { case l: Literal => l } - /** - * Whether every argument of this transform is either a literal parameter or a bare column - * reference (an [[Attribute]] or a [[GetStructField]] chain). 
This is the condition under which - * the transform's column slot can be safely rewritten to a join key (see - * `KeyedShuffleSpec.createPartitioning`): that rewrite replaces each non-literal child wholesale - * with the clustering key, which is only correct when the child IS a plain column reference. - * - * It excludes nested transforms (`bucket(4, years(ts))`), transforms hidden under a wrapper - * (`bucket(4, cast(years(ts)))`), and value-changing slots (`bucket(4, cast(a))`, - * `bucket(4, a + 1)`). A single non-recursive pass suffices: any disqualifying node appears at - * the top of some child, and [[isColumnRef]] rejects it without needing to look inside. - */ - def hasOnlyReferenceArgs: Boolean = children.forall { - case _: Literal => true - case e => isColumnRef(e) - } - /** * Whether this [[TransformExpression]] has the same semantics as `other`. For instance, * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`. @@ -90,50 +74,14 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression */ def isSameFunction(other: TransformExpression): Boolean = function.canonicalName() == other.function.canonicalName() && - childrenMatch(other)(_ == _) - - @scala.annotation.tailrec - private def isColumnRef(e: Expression): Boolean = e match { - case _: Attribute => true - case g: GetStructField => isColumnRef(g.child) - case _ => false - } - - /** - * Whether every non-literal child of this and `other` is structurally the same: nested transforms - * must recursively be the same function, and any other slot must be a column reference. Literal - * children may differ -- they are exactly the parameters a [[Reducer]] is allowed to reconcile. - * - * This guards the reducer path. A [[Reducer]] is derived from the outer literal parameters alone - * (e.g. bucket numBuckets, truncate width); the nested transform children are not visible to it. - * It is therefore only valid when those nested children are identical. 
Without this check, - * `bucket(4, years(ts))` and `bucket(2, days(ts))` would be reduced via `gcd(4, 2) = 2`, silently - * joining mismatched partitions even though `years(ts)` and `days(ts)` are different transforms. - */ - private def nonLiteralChildrenSame(other: TransformExpression): Boolean = - childrenMatch(other)((_, _) => true) - - /** - * Pairwise-match this transform's children against `other`'s. Requires equal arity, recursively - * the same function for nested transform arguments, and a plain column reference (Attribute / - * GetStructField chain) on both sides for any other slot. The `literalsMatch` predicate decides - * how literal parameters are compared: - * - `_ == _` for exact equality ([[isSameFunction]]); - * - `(_, _) => true` to allow them to differ ([[nonLiteralChildrenSame]], the reducer check, - * where differing literal parameters are exactly what a [[Reducer]] reconciles). - * - * Note nested transform arguments always require full sameness via [[isSameFunction]] regardless - * of `literalsMatch`: the reducer is blind to nested transforms, so they must be identical. - */ - private def childrenMatch(other: TransformExpression) - (literalsMatch: (Literal, Literal) => Boolean): Boolean = - children.length == other.children.length && + children.length == other.children.length && children.zip(other.children).forall { - case (l1: Literal, l2: Literal) => literalsMatch(l1, l2) + case (l1: Literal, l2: Literal) => l1 == l2 case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2) - // any other pair: both must be plain column references; column identity is ignored, but a - // non-reference slot (Add, Cast, ...) or a Literal/Transform-vs-ref mismatch is "not same" - case (c1, c2) => isColumnRef(c1) && isColumnRef(c2) + // Any other pair must be a plain column reference on both sides. Column identity is + // ignored (reconciled separately via positional matching); a non-reference slot + // (Add, Cast, ...) 
or a literal/transform-vs-reference mismatch is "not the same". + case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2) } /** @@ -203,13 +151,6 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], otherExpr: TransformExpression): Option[Reducer[_, _]] = { - // The reducer is derived from the literal parameters only (extractParameters drops nested - // transform children), so it is valid only when every non-literal child is structurally - // identical. This protects both `isCompatible` and the public `reducers` entry point from - // reducing across unrelated nested transforms, e.g. bucket(4, years(ts)) vs bucket(2, days(ts)) - if (!thisExpr.nonLiteralChildrenSame(otherExpr)) { - return None - } val thisParams = extractParameters(thisExpr) val otherParams = extractParameters(otherExpr) val thisName = thisExpr.function.canonicalName() @@ -271,3 +212,18 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } + +object TransformExpression { + /** + * Whether `e` is a bare column reference: an [[Attribute]] or a [[GetStructField]] chain + * (struct-field access on a column). Shared by [[TransformExpression.isSameFunction]] and by + * `KeyedPartitioning.supportsExpressions`, which both decide whether a transform's single + * non-literal argument is a plain column. 
+ */ + @tailrec + private[sql] def isColumnRef(e: Expression): Boolean = e match { + case _: Attribute => true + case g: GetStructField => isColumnRef(g.child) + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0d6154fdd662f..fc4d5bc5da73f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.physical -import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.{SparkException, SparkUnsupportedOperationException} @@ -641,22 +640,15 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - // Use Expression.references (AttributeSet), which already filters out non-attribute - // leaves (e.g., Literal parameters introduced by parameterized transforms such as - // `bucket(numBuckets, col)` or `truncate(col, width)`). - transform.references.size == 1 - } - - @tailrec - def isReference(e: Expression): Boolean = e match { - case _: Attribute => true - case g: GetStructField => isReference(g.child) - case _ => false + // Should only consider column references, not literals. + val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal]) + // We need exactly one column reference per transform. 
+ nonLiteralChildren.size == 1 && TransformExpression.isColumnRef(nonLiteralChildren.head) } expressions.forall { case t: TransformExpression if isSupportedTransform(t) => true - case e: Expression if isReference(e) => true + case e: Expression if TransformExpression.isColumnRef(e) => true case _ => false } } @@ -1328,15 +1320,8 @@ case class KeyedShuffleSpec( override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && - partitioning.expressions.forall { - case _: AttributeReference => true - // Only repartition-to-match a transform whose every argument is a literal or a bare column - // reference. Otherwise createPartitioning would rewrite a non-reference slot (nested - // transform, cast, arithmetic) wholesale to the join key and flatten its semantics away - // (e.g. bucket(4, years(ts)) -> bucket(4, key)), producing a non-co-locating partitioning. - // Fallback to a regular shuffle for those. - case t: TransformExpression => t.hasOnlyReferenceArgs - case _ => false + partitioning.expressions.forall { e => + e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] } override def createPartitioning(clustering: Seq[Expression]): Partitioning = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 4ab3634029891..913e2628e6097 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4363,58 +4363,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } - test("SPARK-50593: nested transforms with differing inner are not compatible " + - "(no false-positive SPJ)") { - import org.apache.spark.sql.catalyst.expressions.Expression - import 
org.apache.spark.sql.types.StringType - // A reducer is derived from the outer literal parameters only and is blind to nested inner - // transforms (extractParameters keeps only literalChildren). The nonLiteralChildrenSame guard - // in TransformExpression.reducer must refuse to reduce when the inner transforms differ; - // otherwise SPJ would silently co-locate mismatched partitions and drop matching rows. - val ts = attr("ts") - val data = AttributeReference("data", StringType)() - def bucket(n: Int, e: Expression): TransformExpression = - TransformExpression(BucketFunction, Seq(Literal(n), e)) - def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) - def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e)) - def truncate(e: Expression, w: Int): TransformExpression = - TransformExpression(TruncateFunction, Seq(e, Literal(w))) - - // Outer bucket reducible (4 vs 2) but inner differs (years vs days) -> NOT compatible. - val l1 = bucket(4, years(ts)) - val r1 = bucket(2, days(ts)) - assert(!l1.isSameFunction(r1)) - assert(!l1.isCompatible(r1), - "bucket(4, years(ts)) must not be compatible with bucket(2, days(ts))") - assert(l1.reducers(r1).isEmpty && r1.reducers(l1).isEmpty, - "no reducer may be produced across differing inner transforms") - - // No-arg reducer path: outer days/years (empty params) over differing inner buckets. - val l2 = days(bucket(4, ts)) - val r2 = years(bucket(2, ts)) - assert(!l2.isCompatible(r2), - "days(bucket(4, ts)) must not be compatible with years(bucket(2, ts))") - assert(l2.reducers(r2).isEmpty) - - // Truncate manifestation: outer width reducible (10 -> 5) but inner differs -> NOT compatible. 
- val l3 = truncate(bucket(4, data), 10) - val r3 = truncate(bucket(2, data), 5) - assert(!l3.isCompatible(r3), - "truncate(.., 10) over bucket(4) must not be compatible with truncate(.., 5) over bucket(2)") - assert(l3.reducers(r3).isEmpty) - - // Positive control 1: same inner, outer reducible -> still compatible with a reducer. - val same1 = bucket(4, years(ts)) - val same2 = bucket(2, years(ts)) - assert(same1.isCompatible(same2), - "same inner years(ts) with bucket(4) vs bucket(2) must remain compatible") - assert(same1.reducers(same2).isDefined) - - // Positive control 2: flat cross-function reducible days(ts) vs years(ts) -> still compatible. - assert(days(ts).isCompatible(years(ts)), - "flat days(ts) vs years(ts) must remain compatible (no regression)") - } - test("SPARK-50593: isSameFunction recurses into nested transforms, respects column-ref slots") { import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GetStructField} val a = attr("a") @@ -4424,8 +4372,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e)) - // Nested identical -> same, with column identity ignored. Without this, the primary nested SPJ - // case (bucket(4, years(ts)) on both sides) would silently fall back to a shuffle. + // Nested identical -> same (recursing into the inner transform), with column identity ignored. + // isSameFunction stays correct for nested shapes even though the SPJ gate currently rejects + // them; keeping this behavior is the right shape for any future nested support. assert(bucket(4, years(a)).isSameFunction(bucket(4, years(a)))) assert(bucket(4, years(a)).isSameFunction(bucket(4, years(b))), "column identity is ignored") // Nested different inner -> not same. 
@@ -4448,54 +4397,32 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(bucket(4, sf).isSameFunction(bucket(4, sf))) } - test("SPARK-50593: supportsExpressions admits single-attribute nested, rejects multi-attribute") { - import org.apache.spark.sql.catalyst.expressions.Add + test("SPARK-50593: supportsExpressions admits flat parameterized transforms, " + + "rejects nested and non-reference slots") { + import org.apache.spark.sql.catalyst.expressions.{Add, Expression} val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() - - // Single-attribute nested transform -> admitted (references = {a}). - val nested = TransformExpression(BucketFunction, - Seq(Literal(4), TransformExpression(YearsFunction, Seq(a)))) - assert(nested.references.size == 1) - assert(physical.KeyedPartitioning.supportsExpressions(Seq(nested))) - - // Multi-attribute transform -> rejected at the gate; the positional model (keyPositions) needs - // exactly one clustering column per partition expression. - val multi = TransformExpression(BucketFunction, Seq(Literal(4), Add(a, b))) - assert(multi.references.size == 2) - assert(!physical.KeyedPartitioning.supportsExpressions(Seq(multi))) - } - - test("SPARK-50593: canCreatePartitioning is false for nested transforms (avoids flattening)") { - import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Expression} - val ts = AttributeReference("ts", TimestampType)() - val n = AttributeReference("n", IntegerType)() def bucket(n: Int, e: Expression): TransformExpression = TransformExpression(BucketFunction, Seq(Literal(n), e)) - def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) - // hasOnlyReferenceArgs: true only when every arg is a literal or a bare column reference. 
- assert(bucket(4, ts).hasOnlyReferenceArgs) // bare column - assert(!bucket(4, years(ts)).hasOnlyReferenceArgs) // nested transform - assert(!bucket(4, Cast(years(ts), IntegerType)).hasOnlyReferenceArgs) // transform under a cast - assert(!bucket(4, Cast(n, LongType)).hasOnlyReferenceArgs) // cast slot - assert(!bucket(4, Add(n, Literal(1))).hasOnlyReferenceArgs) // arithmetic slot + // Flat parameterized transform over a bare column -> admitted (one non-literal child = column). + assert(physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, a)))) + // Bare identity column -> admitted. + assert(physical.KeyedPartitioning.supportsExpressions(Seq(a))) - val dist = physical.ClusteredDistribution(Seq(ts)) - val nested = physical.KeyedPartitioning(Seq(bucket(4, years(ts))), Seq.empty) - val flat = physical.KeyedPartitioning(Seq(bucket(4, ts)), Seq.empty) + // Nested transform -> rejected: the non-literal child is a transform, not a column reference. + // SPJ reasons about a transform via its function and literal params alone, which is unsound + // when the remaining argument is itself a transform. + val nested = bucket(4, TransformExpression(YearsFunction, Seq(a))) + assert(!physical.KeyedPartitioning.supportsExpressions(Seq(nested))) - withSQLConf( - SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false") { - // A nested transform must NOT be repartitioned-to-match: createPartitioning would flatten - // bucket(4, years(ts)) -> bucket(4, key), producing a non-co-locating partitioning. It must - // opt out so EnsureRequirements falls back to a regular shuffle. - assert(!physical.KeyedShuffleSpec(nested, dist).canCreatePartitioning, - "nested transform must opt out of createPartitioning") - // Flat transform remains repartitionable (no regression for v2 bucketing-shuffle). 
- assert(physical.KeyedShuffleSpec(flat, dist).canCreatePartitioning) - } + // Value-changing slot (a + 1) -> rejected: not a plain column reference. + assert(!physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, Add(a, Literal(1)))))) + + // Two non-literal column references -> rejected: a partition expression must map to exactly one + // clustering column (the positional keyPositions model needs a single column per transform). + assert(!physical.KeyedPartitioning.supportsExpressions( + Seq(TransformExpression(BucketFunction, Seq(Literal(4), a, b))))) } test("SPARK-50593: integer truncate is reducible via lcm (generalized reducer, non-bucket)") { From 28368ca7d7e70d56c9c3c98accf8b2186eb3fa70 Mon Sep 17 00:00:00 2001 From: metanil Date: Thu, 18 Jun 2026 15:15:37 -0700 Subject: [PATCH 7/9] [SPARK-50593][SQL] Replace ReducibleParameters with V2 Literal[] in the reducer API Carry reducible-function parameters as V2 Literal[] (value + DataType) instead of the ReducibleParameters wrapper, and delete that class. - ReducibleFunction.reducer(Literal[], ReducibleFunction, Literal[]) replaces the ReducibleParameters overload; the deprecated reducer(int, ..., int) is kept for legacy connectors (e.g. Iceberg 1.10). - TransformExpression.extractParameters wraps each literal child as LiteralValue(value, dataType) -- internal value + DataType, no CatalystTypeConverters conversion -- so connectors interpret values via dataType() (resolves the config-dependent Date/Timestamp concern). - Reducer dispatch hardening: - sameArgumentLayout: reduce only when literal/column positions align on both sides (f(id, 2) vs f(4, store_id) is not reducible). - Gate the deprecated int path on dataType == IntegerType, not the boxed runtime class. - Fall back to the generalized overload when the deprecated one is absent or returns null. - Tests: Literal[] reducer end-to-end, deprecated backward-compat, DateType routing, mismatched argument layout, deprecated-null fallback. 
--- .../catalog/functions/ReducibleFunction.java | 33 ++-- .../functions/ReducibleParameters.java | 147 ------------------ .../expressions/TransformExpression.scala | 103 ++++++------ .../KeyGroupedPartitioningSuite.scala | 68 ++++++-- .../functions/transformFunctions.scala | 66 ++++++-- 5 files changed, 182 insertions(+), 235 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index c496196f3ca7c..90ab50344c7b3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; /** * Base class for user-defined functions that can be 'reduced' on another function. @@ -65,34 +66,28 @@ public interface ReducibleFunction { * * If this function is 'reducible' on another function, return the {@link Reducer}. *

- * This method supports functions with any number of parameters of any type.
+ * Each parameter is a {@link Literal} carrying both its value and data type, so this method
+ * supports functions with any number of parameters of any type. {@link Literal#value()} is
+ * Spark's internal representation (e.g. {@code UTF8String} for strings, {@code Decimal} for
+ * decimals); use {@link Literal#dataType()} to interpret it.
 * <p>
 * Examples:
 * <ul>
- *   <li>bucket(4, x) and bucket(2, x):
- *     <br>thisParams = [4], otherParams = [2]
- *     <br>Extract with: thisParams.getInt(0), otherParams.getInt(0)
- *   </li>
- *   <li>truncate(x, 3) and truncate(x, 5):
- *     <br>thisParams = [3], otherParams = [5]
- *     <br>Extract with: thisParams.getInt(0), otherParams.getInt(0)
- *   </li>
- *   <li>hypothetical range_bucket(x, 0L, 100L, 4):
- *     <br>thisParams = [0L, 100L, 4]
- *     <br>Extract with: thisParams.getLong(0), thisParams.getLong(1), thisParams.getInt(2)
- *   </li>
+ *   <li>bucket(4, x) and bucket(2, x): thisParams = [4], otherParams = [2]</li>
+ *   <li>truncate(x, 3) and truncate(x, 5): thisParams = [3], otherParams = [5]</li>
+ *   <li>hypothetical range_bucket(x, 0L, 100L, 4): thisParams = [0L, 100L, 4]</li>
 * </ul>
    * - * @param thisParams parameters for this function + * @param thisParams literal parameters for this function * @param otherFunction the other parameterized function - * @param otherParams parameters for the other function + * @param otherParams literal parameters for the other function * @return a reduction function if reducible, null otherwise * @since 5.0.0 */ default Reducer reducer( - ReducibleParameters thisParams, + Literal[] thisParams, ReducibleFunction otherFunction, - ReducibleParameters otherParams) { + Literal[] otherParams) { throw new UnsupportedOperationException(); } @@ -115,7 +110,7 @@ default Reducer reducer( * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not * @deprecated as of 5.0.0. Please override - * {@link #reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters)} instead. + * {@link #reducer(Literal[], ReducibleFunction, Literal[])} instead. * The new overload supports transforms with any number of parameters of any type * (e.g. truncate width, multi-arg range buckets), not just a single int. */ @@ -142,6 +137,6 @@ default Reducer reducer( * @return a reduction function if it is reducible, null if not. */ default Reducer reducer(ReducibleFunction otherFunction) { - return reducer(ReducibleParameters.EMPTY, otherFunction, ReducibleParameters.EMPTY); + return reducer(new Literal[0], otherFunction, new Literal[0]); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java deleted file mode 100644 index 836cd5d890e5d..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleParameters.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.connector.catalog.functions; - -import org.apache.spark.annotation.Evolving; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * Container for reducible function literal parameters. - * Provides type-safe access to parameters of various types. - * - * Examples: - *
      - *
    • bucket(4, col) → ReducibleParameters([4])
    • - *
    • truncate(col, 3) → ReducibleParameters([3])
    • - *
    • range_bucket(col, 0L, 100L, 10) → ReducibleParameters([0L, 100L, 10])
    • - *
    • custom_transform(col, "param") → ReducibleParameters(["param"])
    • - *
    - * - * @since 5.0.0 - */ -@Evolving -public class ReducibleParameters { - public static final ReducibleParameters EMPTY = new ReducibleParameters(); - - private final List values; - - private ReducibleParameters() { - this.values = new ArrayList<>(); - } - - public ReducibleParameters(List values) { - this.values = values; - } - - public ReducibleParameters(Object... values) { - this.values = Arrays.asList(values); - } - - /** - * Get the number of parameters. - */ - public int count() { - return values.size(); - } - - /** - * Check if this container has parameters. - */ - public boolean isEmpty() { - return values.isEmpty(); - } - - /** - * Get parameter at index as Integer. - * @throws ClassCastException if parameter is not an Integer - * @throws IndexOutOfBoundsException if index is invalid - */ - public int getInt(int index) { - return (Integer) values.get(index); - } - - /** - * Get parameter at index as Long. - * @throws ClassCastException if parameter is not a Long - * @throws IndexOutOfBoundsException if index is invalid - */ - public long getLong(int index) { - return (Long) values.get(index); - } - - /** - * Get parameter at index as String. - * @throws ClassCastException if parameter is not a String - * @throws IndexOutOfBoundsException if index is invalid - */ - public String getString(int index) { - return (String) values.get(index); - } - - /** - * Get parameter at index as Double. - * @throws ClassCastException if parameter is not a Double - * @throws IndexOutOfBoundsException if index is invalid - */ - public double getDouble(int index) { - return (Double) values.get(index); - } - - /** - * Get parameter at index as Float. - * @throws ClassCastException if parameter is not a Float - * @throws IndexOutOfBoundsException if index is invalid - */ - public float getFloat(int index) { - return (Float) values.get(index); - } - - /** - * Get raw parameter value at index. 
- */ - public Object get(int index) { - return values.get(index); - } - - /** - * Get all parameter values as a list. - */ - public List getAll() { - return values; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ReducibleParameters that = (ReducibleParameters) o; - return values.equals(that.values); - } - - @Override - public int hashCode() { - return values.hashCode(); - } - - @Override - public String toString() { - return "ReducibleParameters(" + values + ")"; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index c0eab818f2a7f..74c87fbaaf32e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -22,11 +22,12 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.FUNCTION_NAME -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.connector.expressions.{Literal => V2Literal, LiteralValue} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. 
@@ -126,21 +127,30 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression } /** - * Extract all literal parameters from a transform expression. - * Returns ReducibleParameters containing the literal values in order. + * Extract all literal parameters from a transform expression as V2 [[V2Literal]]s, preserving + * each value's internal representation and its `DataType`. Connectors interpret the value via + * the accompanying `DataType` rather than relying on a pre-converted JVM type. * * Examples: - * bucket(4, col) => ReducibleParameters([4]) - * truncate(col, 3) => ReducibleParameters([3]) - * days(col) => ReducibleParameters([]) (no literals) + * bucket(4, col) => [Literal(4, IntegerType)] + * truncate(col, 3) => [Literal(3, IntegerType)] + * days(col) => [] (no literals) */ - private def extractParameters(expr: TransformExpression): ReducibleParameters = { - import scala.jdk.CollectionConverters._ - val values = expr.literalChildren.map { - case Literal(value, dt) => CatalystTypeConverters.convertToScala(value, dt) - } - new ReducibleParameters(values.asJava) - } + private def extractParameters(expr: TransformExpression): Array[V2Literal[_]] = + expr.literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray + + /** + * Whether this transform and `other` share the same argument layout: equal arity, and at each + * position a literal slot aligns with a literal slot (and a non-literal with a non-literal). + * Literal *values* may differ -- that is what a [[Reducer]] reconciles. 
+ */ + private def sameArgumentLayout(other: TransformExpression): Boolean = + children.length == other.children.length && + children.zip(other.children).forall { + case (_: Literal, _: Literal) => true + case (_: Literal, _) | (_, _: Literal) => false + case _ => true + } /** * Return a Reducer for a reducible function on another reducible function @@ -151,44 +161,49 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], otherExpr: TransformExpression): Option[Reducer[_, _]] = { + if (!thisExpr.sameArgumentLayout(otherExpr)) { + return None + } + val thisParams = extractParameters(thisExpr) val otherParams = extractParameters(otherExpr) val thisName = thisExpr.function.canonicalName() - def isSingleInt(p: ReducibleParameters): Boolean = { - p.count() == 1 && p.get(0).isInstanceOf[Int] + // Gate on DataType, not the boxed runtime class (DateType/YearMonthInterval box to Int). + def isSingleInt(p: Array[V2Literal[_]]): Boolean = { + p.length == 1 && p(0).dataType == IntegerType } - // Both thrown exceptions and `null` returns collapse to None; any failure - // to compute a reducer falls back to a shuffle (no SPJ). - def tryReduce[R](call: => R): Try[Option[R]] = { - val attempt = Try(Option(call)) - attempt.failed.foreach { - case e: UnsupportedOperationException => - logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " + - log"UnsupportedOperationException; treating as not reducible. Override " + - log"reducer(ReducibleParameters, ReducibleFunction, ReducibleParameters) " + - log"to enable SPJ.") - case _ => + // Run a reducer overload; a thrown exception or a null both become None. warnOnUoe logs a hint + // when the function implements no usable reducer overload. 
+ def attempt[R](call: => R, warnOnUoe: Boolean): Option[R] = { + val t = Try(Option(call)) + if (warnOnUoe) { + t.failed.foreach { + case _: UnsupportedOperationException => + logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " + + log"UnsupportedOperationException; treating as not reducible. Override " + + log"reducer(Literal[], ReducibleFunction, Literal[]) to enable SPJ.") + case _ => + } } - - attempt + t.toOption.flatten } - val res: Try[Option[Reducer[_, _]]] = - if (thisParams.isEmpty && otherParams.isEmpty) { - tryReduce(thisFunction.reducer(otherFunction)) - } else if (isSingleInt(thisParams) && isSingleInt(otherParams)) { - // Try deprecated int-API first for legacy connectors (e.g. Iceberg 1.10); - // the first attempt is silent because we have a fallback. Only the fallback warns. - Try(Option(thisFunction.reducer( - thisParams.getInt(0), otherFunction, otherParams.getInt(0)))) - .orElse(tryReduce(thisFunction.reducer(thisParams, otherFunction, otherParams))) - } else { - // Parameterized functions (bucket, truncate, etc.) - tryReduce(thisFunction.reducer(thisParams, otherFunction, otherParams)) - } - res.toOption.flatten + if (thisParams.isEmpty && otherParams.isEmpty) { + attempt(thisFunction.reducer(otherFunction), warnOnUoe = true) + } else if (isSingleInt(thisParams) && isSingleInt(otherParams)) { + // Try the deprecated int API first (legacy connectors); fall back to the generalized overload + // when it is absent or returns null. Option.orElse fires on None, covering both. + attempt(thisFunction.reducer( + thisParams(0).value().asInstanceOf[Int], otherFunction, + otherParams(0).value().asInstanceOf[Int]), warnOnUoe = false) + .orElse( + attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true)) + } else { + // Parameterized functions (bucket, truncate, etc.) 
+ attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true) + } } override def dataType: DataType = function.resultType() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 913e2628e6097..4a29d272d899f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4189,8 +4189,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } - // === SPARK-50593: Truncate SPJ support tests === - test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") { val partitions1 = Array( Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)) @@ -4233,7 +4231,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") { - // Exercises the ReducibleParameters-based reducer path end-to-end: truncate widths 3 and 5 + // Exercises the Literal[]-based reducer path end-to-end: truncate widths 3 and 5 // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle. 
val table1 = "trunc_three" val table2 = "trunc_five" @@ -4276,10 +4274,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } - test("SPARK-50593: existing bucket SPJ still works with ReducibleParameters API") { - // Exercises the new ReducibleParameters-based reducer path end-to-end: bucket(4) and + test("SPARK-50593: existing bucket SPJ still works with Literal[] API") { + // Exercises the new Literal[]-based reducer path end-to-end: bucket(4) and // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer - // (now implemented via ReducibleParameters) correctly returns a GCD-based Reducer. + // (now implemented via Literal[] params) correctly returns a GCD-based Reducer. + // BucketFunction overrides only the new API, so this also covers the deprecated->new + // fallback: the single-int dispatch tries reducer(int, ...) first (UOE), then the Literal[]. val table1 = "bucket_compat1" val table2 = "bucket_compat2" @@ -4426,7 +4426,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-50593: integer truncate is reducible via lcm (generalized reducer, non-bucket)") { - // A second reducible transform exercising the generalized ReducibleParameters API with reducer + // A second reducible transform exercising the generalized Literal[] reducer API with reducer // math distinct from bucket (GCD) and string truncate (prefix-min): integer truncate snaps to // a coarser grid, so truncate(v, W1) and truncate(v, W2) reduce onto multiples of lcm(W1, W2). import org.apache.spark.sql.catalyst.expressions.Expression @@ -4454,11 +4454,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(itrunc(id, 3).isCompatible(itrunc(id, 5))) // coprime -> lcm(3, 5) = 15 } - test("SPARK-50593: ReducibleParameters backward compat - old int API still works via default") { - // Verifies the default reducer(ReducibleParameters, ...) 
implementation correctly - // delegates to the deprecated reducer(int, func, int) when a ReducibleFunction only - // overrides the old API. This mirrors how Iceberg 1.10.0 (and earlier) ship without - // knowledge of ReducibleParameters. + test("SPARK-50593: deprecated int reducer API still works (legacy connector backward compat)") { + // The reducer dispatch attempts the deprecated reducer(int, func, int) first for single-int + // params, so a ReducibleFunction that overrides ONLY the deprecated method still reduces. + // This mirrors how Iceberg 1.10.0 (and earlier) ship -- they predate the Literal[] API. val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id"))) val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id"))) @@ -4471,4 +4470,49 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}") assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}") } + + test("SPARK-50593: a non-IntegerType param (DateType) does not reach the deprecated " + + "int reducer") { + // DateType is stored as a boxed Integer (epoch days) internally, so the reducer dispatch must + // key off the DataType, not the runtime class -- otherwise a DateType param is mistaken for the + // bucket-style int param and routed to the deprecated reducer(int, ...). LegacyBucketFunction + // overrides ONLY that deprecated method, so with a DateType param it must be unreachable, + // leaving the pair not reducible (rather than producing a bogus GCD reducer over epoch-days). 
+ val l = TransformExpression(LegacyBucketFunction, Seq(Literal(8, DateType), attr("id"))) + val r = TransformExpression(LegacyBucketFunction, Seq(Literal(4, DateType), attr("id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "a DateType param must not reach the deprecated int reducer") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + } + + test("SPARK-50593: mismatched column/literal argument layout is not reducible") { + // Both transforms pass the strict gate (one column-reference non-literal child), but the column + // and literal sit in swapped positions: truncate(id, 2) is (col, lit) while truncate(4, sid) is + // (lit, col). The reducer only sees the literal positions ([2] vs [4]), so without an + // argument-layout check it would wrongly reduce these and co-locate non-matching rows. + // IntegerTruncateFunction has two same-typed (Int) args, which makes this layout reachable. + val l = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2))) + val r = TransformExpression(IntegerTruncateFunction, Seq(Literal(4), attr("store_id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "swapped column/literal layout must not be reducible") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + + // Control: same layout (col, lit) on both sides remains reducible via lcm(2, 4). + val a = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2))) + val b = TransformExpression(IntegerTruncateFunction, Seq(attr("store_id"), Literal(4))) + assert(a.isCompatible(b), "aligned (col, lit) layout must remain reducible") + } + + test("SPARK-50593: deprecated reducer returning null falls back to the generalized overload") { + // DualApiBucketFunction implements both overloads: the deprecated reducer(int, ...) returns + // null, while the generalized reducer(Literal[], ...) returns a valid GCD reducer. 
The dispatch + // tries the deprecated one first; a null there must fall through to the generalized overload + // (Option.orElse fires on None), not be mistaken for "not reducible". + val l = TransformExpression(DualApiBucketFunction, Seq(Literal(4), attr("id"))) + val r = TransformExpression(DualApiBucketFunction, Seq(Literal(2), attr("store_id"))) + assert(l.isCompatible(r), "a null from the deprecated reducer must fall back to the new API") + val red = l.reducers(r) + assert(red.isDefined, "generalized reducer must be reached when deprecated returns null") + assert(red.get.asInstanceOf[Reducer[Integer, Integer]].reduce(3) == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 094db237a0c8c..e641ce516a7f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -213,13 +214,13 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer( - thisParams: ReducibleParameters, + thisParams: Array[Literal[_]], otherFunc: ReducibleFunction[_, _], - otherParams: ReducibleParameters): Reducer[Int, Int] = { + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { - val thisNumBuckets = thisParams.getInt(0) - val otherNumBuckets = otherParams.getInt(0) + val thisNumBuckets = thisParams(0).value().asInstanceOf[Int] + val otherNumBuckets = 
otherParams(0).value().asInstanceOf[Int] val gcd = this.gcd(thisNumBuckets, otherNumBuckets) if (gcd > 1 && gcd != thisNumBuckets) { @@ -240,7 +241,7 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] { /** * A bucket function that only overrides the deprecated `reducer(int, func, int)` method, - * not the new `reducer(ReducibleParameters, func, ReducibleParameters)` method. + * not the new `reducer(Literal[], func, Literal[])` method. * * Used to verify that the default implementation of the new method correctly falls back * to the deprecated int-based API, so legacy implementations continue to work. @@ -269,6 +270,45 @@ object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[I } } +/** + * A bucket function that implements BOTH reducer overloads: the deprecated `reducer(int, ..., int)` + * always returns null (not reducible via the old API), while the new `reducer(Literal[], ...)` + * returns a GCD-based reducer. Used to verify that the dispatch falls back to the generalized + * overload when the deprecated one returns null (not only when it throws). + */ +object DualApiBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "dual_bucket" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + Math.floorMod(input.getLong(1), input.getInt(0)) + } + + // Deprecated API: intentionally signals "not reducible" via null (not via an exception). + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = null + + // New API: a real GCD-based reducer. 
+ override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { + if (otherFunc == DualApiBucketFunction) { + val thisNumBuckets = thisParams(0).value().asInstanceOf[Int] + val otherNumBuckets = otherParams(0).value().asInstanceOf[Int] + val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt + if (gcd > 1 && gcd != thisNumBuckets) { + return BucketReducer(gcd) + } + } + null + } +} + object UnboundStringSelfFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = StringSelfFunction override def description(): String = name() @@ -328,13 +368,13 @@ object TruncateFunction } override def reducer( - thisParams: ReducibleParameters, + thisParams: Array[Literal[_]], otherFunc: ReducibleFunction[_, _], - otherParams: ReducibleParameters): Reducer[UTF8String, UTF8String] = { + otherParams: Array[Literal[_]]): Reducer[UTF8String, UTF8String] = { if (otherFunc == TruncateFunction) { - val thisWidth = thisParams.getInt(0) - val otherWidth = otherParams.getInt(0) + val thisWidth = thisParams(0).value().asInstanceOf[Int] + val otherWidth = otherParams(0).value().asInstanceOf[Int] val smallerWidth = math.min(thisWidth, otherWidth) if (smallerWidth != thisWidth) { @@ -378,12 +418,12 @@ object IntegerTruncateFunction } override def reducer( - thisParams: ReducibleParameters, + thisParams: Array[Literal[_]], otherFunc: ReducibleFunction[_, _], - otherParams: ReducibleParameters): Reducer[Int, Int] = { + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { if (otherFunc == IntegerTruncateFunction) { - val thisWidth = thisParams.getInt(0) - val otherWidth = otherParams.getInt(0) + val thisWidth = thisParams(0).value().asInstanceOf[Int] + val otherWidth = otherParams(0).value().asInstanceOf[Int] val common = lcm(thisWidth, otherWidth) // Only the finer side reduces; if `common == thisWidth` this side is already the common grid. 
if (common != thisWidth) { From 9d92888fa6093c89ef5aa7ae644cdfeaf9538aa9 Mon Sep 17 00:00:00 2001 From: akhadka Date: Thu, 18 Jun 2026 17:38:23 -0700 Subject: [PATCH 8/9] [SPARK-50593][SQL] Reject non-scalar literal params in the reducer API --- .../catalog/functions/ReducibleFunction.java | 9 ++++--- .../expressions/TransformExpression.scala | 27 ++++++++++++------- .../KeyGroupedPartitioningSuite.scala | 14 ++++++++++ .../functions/transformFunctions.scala | 19 +++++++++++++ 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 90ab50344c7b3..a7bf64a6a71e7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -66,10 +66,11 @@ public interface ReducibleFunction { * * If this function is 'reducible' on another function, return the {@link Reducer}. *

    - * Each parameter is a {@link Literal} carrying both its value and data type, so this method - * supports functions with any number of parameters of any type. {@link Literal#value()} is - * Spark's internal representation (e.g. {@code UTF8String} for strings, {@code Decimal} for - * decimals); use {@link Literal#dataType()} to interpret it. + * Each parameter is a scalar {@link Literal} carrying both its value and data type. Parameters + * are always scalar literals (e.g. bucket numBuckets, truncate width); complex types (array, map, + * struct) are not passed here. {@link Literal#value()} is Spark's internal representation (e.g. + * {@code UTF8String} for strings, {@code Decimal} for decimals); use {@link Literal#dataType()} + * to interpret it. *

    * Examples: *

      diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 74c87fbaaf32e..0119cd68c8bd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} import org.apache.spark.sql.connector.expressions.{Literal => V2Literal, LiteralValue} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{AtomicType, DataType, IntegerType} /** * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. @@ -127,17 +127,17 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression } /** - * Extract all literal parameters from a transform expression as V2 [[V2Literal]]s, preserving - * each value's internal representation and its `DataType`. Connectors interpret the value via - * the accompanying `DataType` rather than relying on a pre-converted JVM type. + * Extract all literal parameters of this transform as V2 [[V2Literal]]s, preserving each value's + * internal representation and its `DataType`. Connectors interpret the value via the accompanying + * `DataType` rather than relying on a pre-converted JVM type. 
* * Examples: * bucket(4, col) => [Literal(4, IntegerType)] * truncate(col, 3) => [Literal(3, IntegerType)] * days(col) => [] (no literals) */ - private def extractParameters(expr: TransformExpression): Array[V2Literal[_]] = - expr.literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray + private def extractParameters: Array[V2Literal[_]] = + literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray /** * Whether this transform and `other` share the same argument layout: equal arity, and at each @@ -152,6 +152,14 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression case _ => true } + /** + * Whether every literal parameter is a scalar (an [[AtomicType]]). Reducer parameters are scalar + * literals; this never forwards a complex Catalyst container (ArrayData / MapData / InternalRow) + * across the public reducer boundary -- such a transform is simply treated as not reducible. + */ + private def scalarLiteralParams: Boolean = + literalChildren.forall(_.dataType.isInstanceOf[AtomicType]) + /** * Return a Reducer for a reducible function on another reducible function * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions. @@ -161,12 +169,13 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], otherExpr: TransformExpression): Option[Reducer[_, _]] = { - if (!thisExpr.sameArgumentLayout(otherExpr)) { + if (!thisExpr.sameArgumentLayout(otherExpr) || + !thisExpr.scalarLiteralParams || !otherExpr.scalarLiteralParams) { return None } - val thisParams = extractParameters(thisExpr) - val otherParams = extractParameters(otherExpr) + val thisParams = thisExpr.extractParameters + val otherParams = otherExpr.extractParameters val thisName = thisExpr.function.canonicalName() // Gate on DataType, not the boxed runtime class (DateType/YearMonthInterval box to Int). 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 4a29d272d899f..7dc5749c537c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -4515,4 +4515,18 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(red.isDefined, "generalized reducer must be reached when deprecated returns null") assert(red.get.asInstanceOf[Reducer[Integer, Integer]].reduce(3) == 1) } + + test("SPARK-50593: a complex (non-scalar) literal param is not reducible") { + // Reducer parameters must be scalar literals. ArrayParamFunction's generalized reducer returns + // a reducer unconditionally, so reaching it at all is the leak; the scalar guard must refuse + // the ArrayType literal param first. Different array values keep isSameFunction false, forcing + // the reducer path where the guard applies. 
+ val l = TransformExpression(ArrayParamFunction, + Seq(Literal.create(Array(1, 2, 3), ArrayType(IntegerType)), attr("id"))) + val r = TransformExpression(ArrayParamFunction, + Seq(Literal.create(Array(4, 5, 6), ArrayType(IntegerType)), attr("store_id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "a complex literal param must not be reducible") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index e641ce516a7f7..b24a2c731adaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -309,6 +309,25 @@ object DualApiBucketFunction extends ScalarFunction[Int] with ReducibleFunction[ } } +/** + * A function with a complex (ArrayType) literal parameter. Its generalized reducer returns a valid + * reducer unconditionally, so a test can prove the dispatch refuses to invoke it for a non-scalar + * literal param (rather than the call happening to fail on a cast). 
+ */ +object ArrayParamFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(ArrayType(IntegerType), LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "array_param" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = input.getInt(1) + + override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[Int, Int] = BucketReducer(1) +} + object UnboundStringSelfFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = StringSelfFunction override def description(): String = name() From 7ea6de93372b2855edd224558504ff3364a96e79 Mon Sep 17 00:00:00 2001 From: akhadka Date: Fri, 19 Jun 2026 19:25:51 -0700 Subject: [PATCH 9/9] [SPARK-50593][SQL] Share childrenMatch helper for reducer layout check; stamp reducer API @since 4.3.0 --- .../catalog/functions/ReducibleFunction.java | 6 ++-- .../expressions/TransformExpression.scala | 29 ++++++++++--------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index a7bf64a6a71e7..f05c0f7767949 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -83,7 +83,7 @@ public interface ReducibleFunction { * @param otherFunction the other parameterized function * @param otherParams literal parameters for the other function * @return a reduction function if reducible, null otherwise - * @since 5.0.0 + * @since 4.3.0 */ default Reducer reducer( Literal[] 
thisParams, @@ -110,12 +110,12 @@ default Reducer reducer( * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not - * @deprecated as of 5.0.0. Please override + * @deprecated as of 4.3.0. Please override * {@link #reducer(Literal[], ReducibleFunction, Literal[])} instead. * The new overload supports transforms with any number of parameters of any type * (e.g. truncate width, multi-arg range buckets), not just a single int. */ - @Deprecated(since = "5.0.0") + @Deprecated(since = "4.3.0") default Reducer reducer( int thisNumBuckets, ReducibleFunction otherBucketFunction, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 0119cd68c8bd0..a22aadb22786c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -75,13 +75,19 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression */ def isSameFunction(other: TransformExpression): Boolean = function.canonicalName() == other.function.canonicalName() && - children.length == other.children.length && + childrenMatch(other)(_ == _) + + /** + * Per-position match of children, requiring equal arity. Literal slots are compared by the + * caller-supplied `literalsMatch`; nested transform slots must recursively be the same function; + * any other slot must be a plain column reference on both sides. 
+ */ + private def childrenMatch(other: TransformExpression) + (literalsMatch: (Literal, Literal) => Boolean): Boolean = + children.length == other.children.length && children.zip(other.children).forall { - case (l1: Literal, l2: Literal) => l1 == l2 + case (l1: Literal, l2: Literal) => literalsMatch(l1, l2) case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2) - // Any other pair must be a plain column reference on both sides. Column identity is - // ignored (reconciled separately via positional matching); a non-reference slot - // (Add, Cast, ...) or a literal/transform-vs-reference mismatch is "not the same". case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2) } @@ -140,17 +146,12 @@ case class TransformExpression(function: BoundFunction, children: Seq[Expression literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray /** - * Whether this transform and `other` share the same argument layout: equal arity, and at each - * position a literal slot aligns with a literal slot (and a non-literal with a non-literal). - * Literal *values* may differ -- that is what a [[Reducer]] reconciles. + * Reducer precondition: same argument layout/structure as `other` (arity, aligned slots, equal + * nested transforms, column refs elsewhere). Only literal *values* may differ. Unlike + * [[isSameFunction]] the function name is not compared. */ private def sameArgumentLayout(other: TransformExpression): Boolean = - children.length == other.children.length && - children.zip(other.children).forall { - case (_: Literal, _: Literal) => true - case (_: Literal, _) | (_, _: Literal) => false - case _ => true - } + childrenMatch(other)((_, _) => true) /** * Whether every literal parameter is a scalar (an [[AtomicType]]). Reducer parameters are scalar