diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
index ef1a14e50cdad..f05c0f7767949 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.catalog.functions;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Literal;
/**
* Base class for user-defined functions that can be 'reduced' on another function.
@@ -60,6 +61,37 @@
@Evolving
public interface ReducibleFunction {
+ /**
+ * Generic reducer for parameterized functions (bucket, truncate, etc.).
+ *
+ * If this function is 'reducible' on another function, return the {@link Reducer}.
+ *
+ * Each parameter is a scalar {@link Literal} carrying both its value and data type. Parameters
+ * are always scalar literals (e.g. bucket numBuckets, truncate width); complex types (array, map,
+ * struct) are not passed here. {@link Literal#value()} is Spark's internal representation (e.g.
+ * {@code UTF8String} for strings, {@code Decimal} for decimals); use {@link Literal#dataType()}
+ * to interpret it.
+ *
+ * Examples:
+ *
+ * - bucket(4, x) and bucket(2, x): thisParams = [4], otherParams = [2]
+ * - truncate(x, 3) and truncate(x, 5): thisParams = [3], otherParams = [5]
+ * - hypothetical range_bucket(x, 0L, 100L, 4): thisParams = [0L, 100L, 4]
+ *
+ *
+ * @param thisParams literal parameters for this function
+ * @param otherFunction the other parameterized function
+ * @param otherParams literal parameters for the other function
+ * @return a reduction function if reducible, null otherwise
+ * @since 4.3.0
+ */
+ default Reducer reducer(
+ Literal>[] thisParams,
+ ReducibleFunction, ?> otherFunction,
+ Literal>[] otherParams) {
+ throw new UnsupportedOperationException();
+ }
+
/**
* This method is for the bucket function.
*
@@ -78,7 +110,12 @@ public interface ReducibleFunction {
* @param otherBucketFunction the other parameterized function
* @param otherNumBuckets parameter for the other function
* @return a reduction function if it is reducible, null if not
+ * @deprecated as of 4.3.0. Please override
+ * {@link #reducer(Literal[], ReducibleFunction, Literal[])} instead.
+ * The new overload supports transforms with any number of parameters of any type
+ * (e.g. truncate width, multi-arg range buckets), not just a single int.
*/
+ @Deprecated(since = "4.3.0")
default Reducer reducer(
int thisNumBuckets,
ReducibleFunction, ?> otherBucketFunction,
@@ -101,6 +138,6 @@ default Reducer reducer(
* @return a reduction function if it is reducible, null if not.
*/
default Reducer reducer(ReducibleFunction, ?> otherFunction) {
- throw new UnsupportedOperationException();
+ return reducer(new Literal>[0], otherFunction, new Literal>[0]);
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index 9041ed15fc501..a22aadb22786c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,45 +17,79 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.annotation.tailrec
+import scala.util.Try
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.FUNCTION_NAME
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
+import org.apache.spark.sql.connector.expressions.{Literal => V2Literal, LiteralValue}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AtomicType, DataType, IntegerType}
/**
* Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc.
*
* @param function the transform function itself. Spark will use it to decide whether two
* partition transform expressions are compatible.
- * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise.
*/
-case class TransformExpression(
- function: BoundFunction,
- children: Seq[Expression],
- numBucketsOpt: Option[Int] = None) extends Expression {
+case class TransformExpression(function: BoundFunction, children: Seq[Expression])
+ extends Expression with Logging {
override def nullable: Boolean = true
/**
- * Whether this [[TransformExpression]] has the same semantics as `other`.
- * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or
- * `year(c)`.
+ * Extract literal children (constant parameters) from this transform. These are constant
+ * arguments like width in truncate(col, width). Literals are compared when checking if two
+ * transforms are the same.
+ */
+ private lazy val literalChildren: Seq[Literal] =
+ children.collect { case l: Literal => l }
+
+ /**
+ * Whether this [[TransformExpression]] has the same semantics as `other`. For instance,
+ * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`.
+ * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but may not to `truncate(c, 4)`.
*
* This will be used, for instance, by Spark to determine whether storage-partitioned join can
* be triggered, by comparing partition transforms from both sides of the join and checking
* whether they are compatible.
*
- * @param other the transform expression to compare to
- * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise.
+ * Two transforms are considered the same when they have the same function name, the same arity,
+ * and each pair of corresponding children matches:
+ * - literal arguments must be equal (e.g. numBuckets for bucket, width for truncate), so that
+ * `bucket(32, c)` is not the same as `bucket(16, c)`;
+ * - nested transform arguments must recursively be the same function, so that
+ * `bucket(4, years(c))` is not the same as `bucket(4, days(c))`;
+ * - everything else must be a plain column reference on both sides. Column identity is
+ * intentionally ignored (it is reconciled separately via positional matching), but a
+ * non-reference slot such as `c + 1` or `cast(c)`, or a literal/transform-vs-reference
+ * mismatch, is treated as not the same.
+ *
+ * @param other
+ * the transform expression to compare to
+ * @return
+ * true if this and `other` has the same semantics w.r.t to transform, false otherwise.
*/
- def isSameFunction(other: TransformExpression): Boolean = other match {
- case TransformExpression(otherFunction, _, otherNumBucketsOpt) =>
- function.canonicalName() == otherFunction.canonicalName() &&
- numBucketsOpt == otherNumBucketsOpt
- case _ =>
- false
- }
+ def isSameFunction(other: TransformExpression): Boolean =
+ function.canonicalName() == other.function.canonicalName() &&
+ childrenMatch(other)(_ == _)
+
+ /**
+ * Per-position match of children, requiring equal arity. Literal slots are compared by the
+ * caller-supplied `literalsMatch`; nested transform slots must recursively be the same function;
+ * any other slot must be a plain column reference on both sides.
+ */
+ private def childrenMatch(other: TransformExpression)
+ (literalsMatch: (Literal, Literal) => Boolean): Boolean =
+ children.length == other.children.length &&
+ children.zip(other.children).forall {
+ case (l1: Literal, l2: Literal) => literalsMatch(l1, l2)
+ case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2)
+ case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2)
+ }
/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
@@ -73,8 +107,8 @@ case class TransformExpression(
} else {
(function, other.function) match {
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
- val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
- val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
+ val thisReducer = reducer(f, this, o, other)
+ val otherReducer = reducer(o, other, f, this)
thisReducer.isDefined || otherReducer.isDefined
case _ => false
}
@@ -92,24 +126,94 @@ case class TransformExpression(
*/
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
(function, other.function) match {
- case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
- reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
+ case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
+ reducer(e1, this, e2, other)
case _ => None
}
}
- // Return a Reducer for a reducible function on another reducible function
+ /**
+ * Extract all literal parameters of this transform as V2 [[V2Literal]]s, preserving each value's
+ * internal representation and its `DataType`. Connectors interpret the value via the accompanying
+ * `DataType` rather than relying on a pre-converted JVM type.
+ *
+ * Examples:
+ * bucket(4, col) => [Literal(4, IntegerType)]
+ * truncate(col, 3) => [Literal(3, IntegerType)]
+ * days(col) => [] (no literals)
+ */
+ private def extractParameters: Array[V2Literal[_]] =
+ literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray
+
+ /**
+ * Reducer precondition: same argument layout/structure as `other` (arity, aligned slots, equal
+ * nested transforms, column refs elsewhere). Only literal *values* may differ. Unlike
+ * [[isSameFunction]] the function name is not compared.
+ */
+ private def sameArgumentLayout(other: TransformExpression): Boolean =
+ childrenMatch(other)((_, _) => true)
+
+ /**
+ * Whether every literal parameter is a scalar (an [[AtomicType]]). Reducer parameters are scalar
+ * literals; this never forwards a complex Catalyst container (ArrayData / MapData / InternalRow)
+ * across the public reducer boundary -- such a transform is simply treated as not reducible.
+ */
+ private def scalarLiteralParams: Boolean =
+ literalChildren.forall(_.dataType.isInstanceOf[AtomicType])
+
+ /**
+ * Return a Reducer for a reducible function on another reducible function
+ * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions.
+ */
private def reducer(
thisFunction: ReducibleFunction[_, _],
- thisNumBucketsOpt: Option[Int],
+ thisExpr: TransformExpression,
otherFunction: ReducibleFunction[_, _],
- otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
- val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
- case (Some(numBuckets), Some(otherNumBuckets)) =>
- thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
- case _ => thisFunction.reducer(otherFunction)
+ otherExpr: TransformExpression): Option[Reducer[_, _]] = {
+ if (!thisExpr.sameArgumentLayout(otherExpr) ||
+ !thisExpr.scalarLiteralParams || !otherExpr.scalarLiteralParams) {
+ return None
+ }
+
+ val thisParams = thisExpr.extractParameters
+ val otherParams = otherExpr.extractParameters
+ val thisName = thisExpr.function.canonicalName()
+
+ // Gate on DataType, not the boxed runtime class (DateType/YearMonthInterval box to Int).
+ def isSingleInt(p: Array[V2Literal[_]]): Boolean = {
+ p.length == 1 && p(0).dataType == IntegerType
+ }
+
+ // Run a reducer overload; a thrown exception or a null both become None. warnOnUoe logs a hint
+ // when the function implements no usable reducer overload.
+ def attempt[R](call: => R, warnOnUoe: Boolean): Option[R] = {
+ val t = Try(Option(call))
+ if (warnOnUoe) {
+ t.failed.foreach {
+ case _: UnsupportedOperationException =>
+ logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " +
+ log"UnsupportedOperationException; treating as not reducible. Override " +
+ log"reducer(Literal[], ReducibleFunction, Literal[]) to enable SPJ.")
+ case _ =>
+ }
+ }
+ t.toOption.flatten
+ }
+
+ if (thisParams.isEmpty && otherParams.isEmpty) {
+ attempt(thisFunction.reducer(otherFunction), warnOnUoe = true)
+ } else if (isSingleInt(thisParams) && isSingleInt(otherParams)) {
+ // Try the deprecated int API first (legacy connectors); fall back to the generalized overload
+ // when it is absent or returns null. Option.orElse fires on None, covering both.
+ attempt(thisFunction.reducer(
+ thisParams(0).value().asInstanceOf[Int], otherFunction,
+ otherParams(0).value().asInstanceOf[Int]), warnOnUoe = false)
+ .orElse(
+ attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true))
+ } else {
+ // Parameterized functions (bucket, truncate, etc.)
+ attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true)
}
- Option(res)
}
override def dataType: DataType = function.resultType()
@@ -118,10 +222,7 @@ case class TransformExpression(
copy(children = newChildren)
private lazy val resolvedFunction: Option[Expression] = this match {
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
- Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
- Seq(Literal(numBuckets)) ++ arguments))
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
case _ => None
}
@@ -136,3 +237,18 @@ case class TransformExpression(
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
}
+
+object TransformExpression {
+ /**
+ * Whether `e` is a bare column reference: an [[Attribute]] or a [[GetStructField]] chain
+ * (struct-field access on a column). Shared by [[TransformExpression.isSameFunction]] and by
+ * `KeyedPartitioning.supportsExpressions`, which both decide whether a transform's single
+ * non-literal argument is a plain column.
+ */
+ @tailrec
+ private[sql] def isColumnRef(e: Expression): Boolean = e match {
+ case _: Attribute => true
+ case g: GetStructField => isColumnRef(g.child)
+ case _ => false
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c561677ed5ad5..5967491cf8e07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
import org.apache.spark.sql.connector.read.{SampleMethod => V2SampleMethod}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
@@ -116,17 +116,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match {
case IdentityTransform(ref) =>
Some(resolveRef[NamedExpression](ref, query))
- case BucketTransform(numBuckets, refs, sorted)
- if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
- val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query))
- // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to
- // look up the V2 function.
- val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)()
- funCatalogOpt.flatMap { catalog =>
- loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound =>
- TransformExpression(bound, resolvedRefs, Some(numBuckets))
- }
- }
case NamedTransform(name, args) =>
val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
funCatalogOpt.flatMap { catalog =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index aeacdaec7a8de..fc4d5bc5da73f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.plans.physical
-import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
@@ -641,19 +640,15 @@ object KeyedPartitioning {
def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
- transform.children.size == 1 && isReference(transform.children.head)
- }
-
- @tailrec
- def isReference(e: Expression): Boolean = e match {
- case _: Attribute => true
- case g: GetStructField => isReference(g.child)
- case _ => false
+ // Should only consider column references, not literals.
+ val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal])
+ // We need exactly one column reference per transform.
+ nonLiteralChildren.size == 1 && TransformExpression.isColumnRef(nonLiteralChildren.head)
}
expressions.forall {
case t: TransformExpression if isSupportedTransform(t) => true
- case e: Expression if isReference(e) => true
+ case e: Expression if TransformExpression.isColumnRef(e) => true
case _ => false
}
}
@@ -1335,7 +1330,13 @@ case class KeyedShuffleSpec(
val newExpressions = partitioning.expressions.zip(keyPositions).map {
case (te: TransformExpression, positionSet) =>
- te.copy(children = te.children.map(_ => clustering(positionSet.head)))
+ // Preserve literal parameters (e.g., numBuckets, truncate width)
+ // while replacing only column references with the new clustering expression
+ val newChildren = te.children.map {
+ case l: Literal => l // Keep literals as-is
+ case _ => clustering(positionSet.head) // Replace column references
+ }
+ te.copy(children = newChildren)
case (_, positionSet) => clustering(positionSet.head)
}
KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 02e19dd053f29..9f3c4daa7a6f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, TransformExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -96,9 +96,7 @@ object DistributionAndOrderingUtils {
}
private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
- V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
- case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 711f6dbdcdb11..7dc5749c537c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.functions.{col, max}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with ExplainSuiteHelper {
private val functions = Seq(
@@ -126,7 +127,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
}
@@ -138,7 +139,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
// Has exactly one partition.
val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
@@ -194,13 +195,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
- Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
+ Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts")))))
checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
}
}
- test("non-clustered distribution: V2 function with multiple args") {
+ test("clustered distribution: V2 function with multiple args") {
val partitions: Array[Transform] = Array(
Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
)
@@ -216,7 +217,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val distribution = physical.ClusteredDistribution(
Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
- checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+ // With truncate transform support, KeyedPartitioning should now work
+ val partitionKeys = Seq("aa", "bb", "cc").map(v =>
+ InternalRow(UTF8String.fromString(v)))
+ checkQueryPlan(df, distribution,
+ physical.KeyedPartitioning(distribution.clustering, partitionKeys))
}
/**
@@ -4183,4 +4188,345 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
}
}
}
+
+ test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") {
+ val partitions1 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3))
+ )
+ val partitions2 = Array(
+ Expressions.bucket(4, "data")
+ )
+
+ createTable("trunc_cross1", columns, partitions1)
+ sql("INSERT INTO testcat.ns.trunc_cross1 VALUES " +
+ "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'bbb', CAST('2021-01-01' AS timestamp))")
+
+ createTable("trunc_cross2", columns2, partitions2)
+ sql("INSERT INTO testcat.ns.trunc_cross2 VALUES " +
+ "(1, 5, 'aaa'), " +
+ "(5, 10, 'bbb')")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint("trunc_cross1", "trunc_cross2")}
+ |trunc_cross1.id, trunc_cross2.store_id
+ |FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2
+ |ON trunc_cross1.data = trunc_cross2.data
+ |ORDER BY trunc_cross1.id
+ |""".stripMargin)
+
+ // Different functions (truncate vs bucket) are not mutually reducible, so a shuffle
+ // must still be planned.
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.nonEmpty,
+ "truncate vs bucket are not compatible - a shuffle should be present, " +
+ "but none was planned")
+ checkAnswer(df, Seq(Row(0, 1), Row(1, 5)))
+ }
+ }
+
+ test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") {
+ // Exercises the Literal[]-based reducer path end-to-end: truncate widths 3 and 5
+ // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle.
+ val table1 = "trunc_three"
+ val table2 = "trunc_five"
+
+ val partitions1 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)))
+ val partitions2 = Array(
+ Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(5)))
+
+ createTable(table1, columns, partitions1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 'apple', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'grape', CAST('2021-01-01' AS timestamp)), " +
+ "(2, 'orange', CAST('2020-01-01' AS timestamp))")
+
+ createTable(table2, columns, partitions2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(10, 'apple', CAST('2022-01-01' AS timestamp)), " +
+ "(20, 'grape', CAST('2021-01-01' AS timestamp)), " +
+ "(30, 'orange', CAST('2020-01-01' AS timestamp))")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint(table1, table2)}
+ |$table1.id AS left_id, $table2.id AS right_id
+ |FROM testcat.ns.$table1 JOIN testcat.ns.$table2
+ |ON $table1.data = $table2.data
+ |ORDER BY $table1.id
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty,
+ "truncate(3) vs truncate(5) should avoid shuffle via the width reducer, " +
+ "but a shuffle was planned")
+ checkAnswer(df, Seq(Row(0, 10), Row(1, 20), Row(2, 30)))
+ }
+ }
+
+ test("SPARK-50593: existing bucket SPJ still works with Literal[] API") {
+ // Exercises the new Literal[]-based reducer path end-to-end: bucket(4) and
+ // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer
+ // (now implemented via Literal[] params) correctly returns a GCD-based Reducer.
+ // BucketFunction overrides only the new API, so this also covers the deprecated->new
+ // fallback: the single-int dispatch tries reducer(int, ...) first (UOE), then the Literal[].
+ val table1 = "bucket_compat1"
+ val table2 = "bucket_compat2"
+
+ val partitions1 = Array(Expressions.bucket(4, "id"))
+ val partitions2 = Array(Expressions.bucket(2, "store_id"))
+
+ createTable(table1, columns, partitions1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ "(2, 'ccc', CAST('2020-01-01' AS timestamp)), " +
+ "(3, 'ddd', CAST('2019-01-01' AS timestamp))")
+
+ createTable(table2, columns2, partitions2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(0, 5, 'aaa'), " +
+ "(1, 10, 'bbb'), " +
+ "(2, 15, 'ccc'), " +
+ "(3, 20, 'ddd')")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint(table1, table2)}
+ |$table1.id, $table2.store_id
+ |FROM testcat.ns.$table1 JOIN testcat.ns.$table2
+ |ON $table1.id = $table2.store_id
+ |ORDER BY $table1.id
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.isEmpty,
+ "bucket(4) vs bucket(2) should avoid shuffle via the GCD reducer, " +
+ "but a shuffle was planned")
+ checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3)))
+ }
+ }
+
+ test("SPARK-50593: bucket(4) vs bucket(3) - no common divisor, must shuffle") {
+ // GCD(4, 3) = 1 -- BucketFunction.reducer returns null. Spark must NOT enable SPJ.
+ // Regression guard: a buggy null-handling in TransformExpression.reducer (e.g.,
+ // Try(...).toOption instead of Try(Option(...))) would treat null as Some(null),
+ // enable SPJ, and produce wrong join results for incompatible bucket layouts.
+ val table1 = "bucket_gcd1_a"
+ val table2 = "bucket_gcd1_b"
+
+ val partitions1 = Array(Expressions.bucket(4, "id"))
+ val partitions2 = Array(Expressions.bucket(3, "store_id"))
+
+ createTable(table1, columns, partitions1)
+ sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+ "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+ "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+ "(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+ createTable(table2, columns2, partitions2)
+ sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+ "(0, 5, 'aaa'), " +
+ "(1, 10, 'bbb'), " +
+ "(2, 15, 'ccc')")
+
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+ val df = sql(
+ s"""
+ |${selectWithMergeJoinHint(table1, table2)}
+ |$table1.id, $table2.store_id
+ |FROM testcat.ns.$table1 JOIN testcat.ns.$table2
+ |ON $table1.id = $table2.store_id
+ |ORDER BY $table1.id
+ |""".stripMargin)
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ assert(shuffles.nonEmpty,
+ "bucket(4) vs bucket(3) have no common divisor (GCD=1), so the reducer " +
+ "returns null. SPJ must NOT be enabled; a shuffle is required.")
+ checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2)))
+ }
+ }
+
+ test("SPARK-50593: isSameFunction recurses into nested transforms, respects column-ref slots") {
+ import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GetStructField}
+ val a = attr("a")
+ val b = attr("b")
+ def bucket(n: Int, e: Expression): TransformExpression =
+ TransformExpression(BucketFunction, Seq(Literal(n), e))
+ def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e))
+ def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e))
+
+ // Nested identical -> same (recursing into the inner transform), with column identity ignored.
+ // isSameFunction stays correct for nested shapes even though the SPJ gate currently rejects
+ // them; keeping this behavior is the right shape for any future nested support.
+ assert(bucket(4, years(a)).isSameFunction(bucket(4, years(a))))
+ assert(bucket(4, years(a)).isSameFunction(bucket(4, years(b))), "column identity is ignored")
+ // Nested different inner -> not same.
+ assert(!bucket(4, years(a)).isSameFunction(bucket(4, days(a))))
+ // Different outer literal -> not same.
+ assert(!bucket(4, years(a)).isSameFunction(bucket(2, years(a))))
+ // Flat sanity (no nesting).
+ assert(bucket(4, a).isSameFunction(bucket(4, b)))
+ assert(!bucket(4, a).isSameFunction(bucket(2, b)))
+
+ // A non-reference column slot (a + 1) carries value-changing semantics, so it is conservatively
+ // treated as not-same -- even compared to itself.
+ val add = bucket(4, Add(a, Literal(1)))
+ assert(!add.isSameFunction(bucket(4, Add(b, Literal(1)))))
+ assert(!add.isSameFunction(add), "a non-reference slot is treated as not-same by design")
+
+ // Struct-field column references are recognized (reflexivity preserved for genuine refs).
+ val s = AttributeReference("s", StructType(Seq(StructField("f", IntegerType))))()
+ val sf = GetStructField(s, 0)
+ assert(bucket(4, sf).isSameFunction(bucket(4, sf)))
+ }
+
+ test("SPARK-50593: supportsExpressions admits flat parameterized transforms, " +
+ "rejects nested and non-reference slots") {
+ import org.apache.spark.sql.catalyst.expressions.{Add, Expression}
+ val a = AttributeReference("a", IntegerType)()
+ val b = AttributeReference("b", IntegerType)()
+ def bucket(n: Int, e: Expression): TransformExpression =
+ TransformExpression(BucketFunction, Seq(Literal(n), e))
+
+ // Flat parameterized transform over a bare column -> admitted (one non-literal child = column).
+ assert(physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, a))))
+ // Bare identity column -> admitted.
+ assert(physical.KeyedPartitioning.supportsExpressions(Seq(a)))
+
+ // Nested transform -> rejected: the non-literal child is a transform, not a column reference.
+ // SPJ reasons about a transform via its function and literal params alone, which is unsound
+ // when the remaining argument is itself a transform.
+ val nested = bucket(4, TransformExpression(YearsFunction, Seq(a)))
+ assert(!physical.KeyedPartitioning.supportsExpressions(Seq(nested)))
+
+ // Value-changing slot (a + 1) -> rejected: not a plain column reference.
+ assert(!physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, Add(a, Literal(1))))))
+
+ // Two non-literal column references -> rejected: a partition expression must map to exactly one
+ // clustering column (the positional keyPositions model needs a single column per transform).
+ assert(!physical.KeyedPartitioning.supportsExpressions(
+ Seq(TransformExpression(BucketFunction, Seq(Literal(4), a, b)))))
+ }
+
+ test("SPARK-50593: integer truncate is reducible via lcm (generalized reducer, non-bucket)") {
+ // A second reducible transform exercising the generalized Literal[] reducer API with reducer
+ // math distinct from bucket (GCD) and string truncate (prefix-min): integer truncate snaps to
+ // a coarser grid, so truncate(v, W1) and truncate(v, W2) reduce onto multiples of lcm(W1, W2).
+ import org.apache.spark.sql.catalyst.expressions.Expression
+ val id = attr("id")
+ def itrunc(e: Expression, w: Int): TransformExpression =
+ TransformExpression(IntegerTruncateFunction, Seq(e, Literal(w)))
+
+ // Same width -> same function (no reduction needed).
+ assert(itrunc(id, 4).isSameFunction(itrunc(id, 4)))
+
+ // W2 is a multiple of W1: the finer side (W1=2) reduces onto the coarser grid (W2=4).
+ assert(itrunc(id, 2).isCompatible(itrunc(id, 4)))
+ val r = itrunc(id, 2).reducers(itrunc(id, 4))
+ assert(r.isDefined, "truncate(2) must reduce onto truncate(4)")
+ val red = r.get.asInstanceOf[Reducer[Integer, Integer]]
+ // truncate(.,2) values snapped to multiples of 4: 6 -> 4, 2 -> 0, 8 -> 8
+ assert(red.reduce(6) == 4 && red.reduce(2) == 0 && red.reduce(8) == 8)
+ // The coarser side (4) is already the common grid -> no reducer.
+ assert(itrunc(id, 4).reducers(itrunc(id, 2)).isEmpty)
+
+ // Neither divides the other: both sides reduce to the lcm grid.
+ assert(itrunc(id, 6).isCompatible(itrunc(id, 4))) // lcm(6, 4) = 12
+ assert(itrunc(id, 6).reducers(itrunc(id, 4)).isDefined)
+ assert(itrunc(id, 4).reducers(itrunc(id, 6)).isDefined)
+ assert(itrunc(id, 3).isCompatible(itrunc(id, 5))) // coprime -> lcm(3, 5) = 15
+ }
+
+ test("SPARK-50593: deprecated int reducer API still works (legacy connector backward compat)") {
+ // The reducer dispatch attempts the deprecated reducer(int, func, int) first for single-int
+ // params, so a ReducibleFunction that overrides ONLY the deprecated method still reduces.
+ // This mirrors how Iceberg 1.10.0 (and earlier) ship -- they predate the Literal[] API.
+ val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id")))
+ val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id")))
+
+ val reducer = bucketExpr4.reducers(bucketExpr2)
+ assert(reducer.isDefined, "Expected a reducer for legacy_bucket(4) on legacy_bucket(2)")
+
+ // Verify the returned Reducer actually reduces bucket 4 -> bucket 2 (GCD = 2).
+ // bucket(4, x) produces values in [0, 4); reducing by GCD=2 gives v % 2.
+ val r = reducer.get.asInstanceOf[Reducer[Integer, Integer]]
+ assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}")
+ assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}")
+ }
+
+ test("SPARK-50593: a non-IntegerType param (DateType) does not reach the deprecated " +
+ "int reducer") {
+ // DateType is stored as a boxed Integer (epoch days) internally, so the reducer dispatch must
+ // key off the DataType, not the runtime class -- otherwise a DateType param is mistaken for the
+ // bucket-style int param and routed to the deprecated reducer(int, ...). LegacyBucketFunction
+ // overrides ONLY that deprecated method, so with a DateType param it must be unreachable,
+ // leaving the pair not reducible (rather than producing a bogus GCD reducer over epoch-days).
+ val l = TransformExpression(LegacyBucketFunction, Seq(Literal(8, DateType), attr("id")))
+ val r = TransformExpression(LegacyBucketFunction, Seq(Literal(4, DateType), attr("id")))
+ assert(!l.isSameFunction(r))
+ assert(!l.isCompatible(r), "a DateType param must not reach the deprecated int reducer")
+ assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty)
+ }
+
+ test("SPARK-50593: mismatched column/literal argument layout is not reducible") {
+ // Both transforms pass the strict gate (one column-reference non-literal child), but the column
+ // and literal sit in swapped positions: truncate(id, 2) is (col, lit) while truncate(4, sid) is
+ // (lit, col). The reducer only sees the literal positions ([2] vs [4]), so without an
+ // argument-layout check it would wrongly reduce these and co-locate non-matching rows.
+ // IntegerTruncateFunction has two same-typed (Int) args, which makes this layout reachable.
+ val l = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2)))
+ val r = TransformExpression(IntegerTruncateFunction, Seq(Literal(4), attr("store_id")))
+ assert(!l.isSameFunction(r))
+ assert(!l.isCompatible(r), "swapped column/literal layout must not be reducible")
+ assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty)
+
+ // Control: same layout (col, lit) on both sides remains reducible via lcm(2, 4).
+ val a = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2)))
+ val b = TransformExpression(IntegerTruncateFunction, Seq(attr("store_id"), Literal(4)))
+ assert(a.isCompatible(b), "aligned (col, lit) layout must remain reducible")
+ }
+
+ test("SPARK-50593: deprecated reducer returning null falls back to the generalized overload") {
+ // DualApiBucketFunction implements both overloads: the deprecated reducer(int, ...) returns
+ // null, while the generalized reducer(Literal[], ...) returns a valid GCD reducer. The dispatch
+ // tries the deprecated one first; a null there must fall through to the generalized overload
+ // (Option.orElse fires on None), not be mistaken for "not reducible".
+ val l = TransformExpression(DualApiBucketFunction, Seq(Literal(4), attr("id")))
+ val r = TransformExpression(DualApiBucketFunction, Seq(Literal(2), attr("store_id")))
+ assert(l.isCompatible(r), "a null from the deprecated reducer must fall back to the new API")
+ val red = l.reducers(r)
+ assert(red.isDefined, "generalized reducer must be reached when deprecated returns null")
+ assert(red.get.asInstanceOf[Reducer[Integer, Integer]].reduce(3) == 1)
+ }
+
+ test("SPARK-50593: a complex (non-scalar) literal param is not reducible") {
+ // Reducer parameters must be scalar literals. ArrayParamFunction's generalized reducer returns
+ // a reducer unconditionally, so reaching it at all is the leak; the scalar guard must refuse
+ // the ArrayType literal param first. Different array values keep isSameFunction false, forcing
+ // the reducer path where the guard applies.
+ val l = TransformExpression(ArrayParamFunction,
+ Seq(Literal.create(Array(1, 2, 3), ArrayType(IntegerType)), attr("id")))
+ val r = TransformExpression(ArrayParamFunction,
+ Seq(Literal.create(Array(4, 5, 6), ArrayType(IntegerType)), attr("store_id")))
+ assert(!l.isSameFunction(r))
+ assert(!l.isCompatible(r), "a complex literal param must not be reducible")
+ assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 35102c6893d3b..b24a2c731adaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.Literal
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -213,11 +214,14 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
}
override def reducer(
- thisNumBuckets: Int,
+ thisParams: Array[Literal[_]],
otherFunc: ReducibleFunction[_, _],
- otherNumBuckets: Int): Reducer[Int, Int] = {
+ otherParams: Array[Literal[_]]): Reducer[Int, Int] = {
if (otherFunc == BucketFunction) {
+ val thisNumBuckets = thisParams(0).value().asInstanceOf[Int]
+ val otherNumBuckets = otherParams(0).value().asInstanceOf[Int]
+
val gcd = this.gcd(thisNumBuckets, otherNumBuckets)
if (gcd > 1 && gcd != thisNumBuckets) {
return BucketReducer(gcd)
@@ -235,6 +239,95 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] {
override def displayName(): String = toString
}
+/**
+ * A bucket function that only overrides the deprecated `reducer(int, func, int)` method,
+ * not the new `reducer(Literal[], func, Literal[])` method.
+ *
+ * Used to verify that the default implementation of the new method correctly falls back
+ * to the deprecated int-based API, so legacy implementations continue to work.
+ */
+object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "legacy_bucket"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ Math.floorMod(input.getLong(1), input.getInt(0))
+ }
+
+ override def reducer(
+ thisNumBuckets: Int,
+ otherFunc: ReducibleFunction[_, _],
+ otherNumBuckets: Int): Reducer[Int, Int] = {
+ if (otherFunc == LegacyBucketFunction) {
+ val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
+ if (gcd > 1 && gcd != thisNumBuckets) {
+ return BucketReducer(gcd)
+ }
+ }
+ null
+ }
+}
+
+/**
+ * A bucket function that implements BOTH reducer overloads: the deprecated `reducer(int, ..., int)`
+ * always returns null (not reducible via the old API), while the new `reducer(Literal[], ...)`
+ * returns a GCD-based reducer. Used to verify that the dispatch falls back to the generalized
+ * overload when the deprecated one returns null (not only when it throws).
+ */
+object DualApiBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "dual_bucket"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ Math.floorMod(input.getLong(1), input.getInt(0))
+ }
+
+ // Deprecated API: intentionally signals "not reducible" via null (not via an exception).
+ override def reducer(
+ thisNumBuckets: Int,
+ otherFunc: ReducibleFunction[_, _],
+ otherNumBuckets: Int): Reducer[Int, Int] = null
+
+ // New API: a real GCD-based reducer.
+ override def reducer(
+ thisParams: Array[Literal[_]],
+ otherFunc: ReducibleFunction[_, _],
+ otherParams: Array[Literal[_]]): Reducer[Int, Int] = {
+ if (otherFunc == DualApiBucketFunction) {
+ val thisNumBuckets = thisParams(0).value().asInstanceOf[Int]
+ val otherNumBuckets = otherParams(0).value().asInstanceOf[Int]
+ val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
+ if (gcd > 1 && gcd != thisNumBuckets) {
+ return BucketReducer(gcd)
+ }
+ }
+ null
+ }
+}
+
+/**
+ * A function with a complex (ArrayType) literal parameter. Its generalized reducer returns a valid
+ * reducer unconditionally, so a test can prove the dispatch refuses to invoke it for a non-scalar
+ * literal param (rather than the call happening to fail on a cast).
+ */
+object ArrayParamFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
+ override def inputTypes(): Array[DataType] = Array(ArrayType(IntegerType), LongType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "array_param"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = input.getInt(1)
+
+ override def reducer(
+ thisParams: Array[Literal[_]],
+ otherFunc: ReducibleFunction[_, _],
+ otherParams: Array[Literal[_]]): Reducer[Int, Int] = BucketReducer(1)
+}
+
object UnboundStringSelfFunction extends UnboundFunction {
override def bind(inputType: StructType): BoundFunction = StringSelfFunction
override def description(): String = name()
@@ -253,12 +346,35 @@ object StringSelfFunction extends ScalarFunction[UTF8String] {
}
object UnboundTruncateFunction extends UnboundFunction {
- override def bind(inputType: StructType): BoundFunction = TruncateFunction
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.size == 2) {
+ inputType.head.dataType match {
+ case StringType => TruncateFunction
+ case IntegerType => IntegerTruncateFunction
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"'truncate' does not support data type: ${inputType.head.dataType}")
+ }
+ } else {
+ throw new UnsupportedOperationException(
+ "'truncate' requires exactly 2 arguments: (column, width)")
+ }
+ }
+
override def description(): String = name()
override def name(): String = "truncate"
}
-object TruncateFunction extends ScalarFunction[UTF8String] {
+/**
+ * Truncate transform for String type.
+ * Follows Iceberg spec: truncate(str, L) = str[0:L]
+ *
+ * Implements ReducibleFunction: ANY two different widths are compatible.
+ * The reducer uses the smaller width.
+ */
+object TruncateFunction
+ extends ScalarFunction[UTF8String]
+ with ReducibleFunction[UTF8String, UTF8String] {
override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
override def resultType(): DataType = StringType
override def name(): String = "truncate"
@@ -266,7 +382,84 @@ object TruncateFunction extends ScalarFunction[UTF8String] {
override def toString: String = name()
override def produceResult(input: InternalRow): UTF8String = {
val str = input.getUTF8String(0)
- val length = input.getInt(1)
- str.substring(0, length)
+ val width = input.getInt(1)
+ str.substring(0, width)
+ }
+
+ override def reducer(
+ thisParams: Array[Literal[_]],
+ otherFunc: ReducibleFunction[_, _],
+ otherParams: Array[Literal[_]]): Reducer[UTF8String, UTF8String] = {
+
+ if (otherFunc == TruncateFunction) {
+ val thisWidth = thisParams(0).value().asInstanceOf[Int]
+ val otherWidth = otherParams(0).value().asInstanceOf[Int]
+ val smallerWidth = math.min(thisWidth, otherWidth)
+
+ if (smallerWidth != thisWidth) {
+ return TruncateReducer(smallerWidth)
+ }
+ }
+ null
+ }
+}
+
+case class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] {
+ override def reduce(value: UTF8String): UTF8String = {
+ value.substring(0, width)
+ }
+ override def resultType(): DataType = StringType
+ override def displayName(): String = s"truncate($width)"
+}
+
+/**
+ * Truncate transform for Integer type.
+ * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W), which snaps `value`
+ * down to a multiple of `W`.
+ *
+ * Implements ReducibleFunction: truncate(v, W1) and truncate(v, W2) are always reducible onto a
+ * common coarser grid of multiples of lcm(W1, W2). The finer side (whose width does not already
+ * equal the lcm) reduces by snapping to that grid; when W2 is a multiple of W1 the lcm is simply
+ * the coarser width W2.
+ */
+object IntegerTruncateFunction
+ extends ScalarFunction[Int]
+ with ReducibleFunction[Int, Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "truncate"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ val value = input.getInt(0)
+ val width = input.getInt(1)
+ value - (((value % width) + width) % width)
}
+
+ override def reducer(
+ thisParams: Array[Literal[_]],
+ otherFunc: ReducibleFunction[_, _],
+ otherParams: Array[Literal[_]]): Reducer[Int, Int] = {
+ if (otherFunc == IntegerTruncateFunction) {
+ val thisWidth = thisParams(0).value().asInstanceOf[Int]
+ val otherWidth = otherParams(0).value().asInstanceOf[Int]
+ val common = lcm(thisWidth, otherWidth)
+ // Only the finer side reduces; if `common == thisWidth` this side is already the common grid.
+ if (common != thisWidth) {
+ return IntTruncateReducer(common)
+ }
+ }
+ null
+ }
+
+ private def lcm(a: Int, b: Int): Int = {
+ val g = BigInt(a).gcd(BigInt(b))
+ (BigInt(a) / g * BigInt(b)).toInt
+ }
+}
+
+case class IntTruncateReducer(width: Int) extends Reducer[Int, Int] {
+ override def reduce(value: Int): Int = value - (((value % width) + width) % width)
+ override def resultType(): DataType = IntegerType
+ override def displayName(): String = s"truncate($width)"
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index a70baece77844..629d65bb20c0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, YearsFunction}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -492,7 +492,7 @@ class ProjectedOrderingAndPartitioningSuite
// KP([bucket(32, id)], keys1d) through Project(id as pk) should produce
// KP([bucket(32, pk)], keys1d): the alias is pushed into the bucket's column argument.
val id = AttributeReference("id", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val keys1d = Seq(InternalRow(0), InternalRow(1), InternalRow(2))
val child = DummyLeafExecWithPartitioning(
output = Seq(id),
@@ -507,7 +507,7 @@ class ProjectedOrderingAndPartitioningSuite
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr),
"bucket function and numBuckets must be preserved after alias substitution")
- assert(te.children.head.asInstanceOf[Attribute].name === "pk",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "pk",
"bucket's column argument must be rewritten to the aliased attribute")
case other => fail(s"Expected TransformExpression, got $other")
}
@@ -524,7 +524,7 @@ class ProjectedOrderingAndPartitioningSuite
// Result: KP([bucket(32, id)], keys1d, isNarrowed=true, isGrouped=false).
val id = AttributeReference("id", IntegerType)()
val ts = AttributeReference("ts", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val yearsExpr = TransformExpression(YearsFunction, Seq(ts))
// Projected to position [0] (bucket): (0),(1),(0) -- bucket value 0 appears twice.
val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021))
@@ -539,7 +539,7 @@ class ProjectedOrderingAndPartitioningSuite
kp.expressions.head match {
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr), "bucket must be the surviving expression")
- assert(te.children.head.asInstanceOf[Attribute].name === "id")
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id")
case other => fail(s"Expected TransformExpression, got $other")
}
assert(kp.isNarrowed, "dropping years(ts) position must mark the KP as narrowed")
@@ -554,7 +554,7 @@ class ProjectedOrderingAndPartitioningSuite
// Result: KP([bucket(32, id), years(ts_alias)], keys2d) -- not narrowed.
val id = AttributeReference("id", IntegerType)()
val ts = AttributeReference("ts", IntegerType)()
- val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32))
+ val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id))
val yearsExpr = TransformExpression(YearsFunction, Seq(ts))
val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021))
val child = DummyLeafExecWithPartitioning(
@@ -569,14 +569,14 @@ class ProjectedOrderingAndPartitioningSuite
kp.expressions(0) match {
case te: TransformExpression =>
assert(te.isSameFunction(bucketExpr))
- assert(te.children.head.asInstanceOf[Attribute].name === "id",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id",
"bucket's argument must remain id (no alias for id in this projection)")
case other => fail(s"Expected TransformExpression at pos 0, got $other")
}
kp.expressions(1) match {
case te: TransformExpression =>
assert(te.isSameFunction(yearsExpr))
- assert(te.children.head.asInstanceOf[Attribute].name === "ts_alias",
+ assert(te.children.collectFirst { case a: Attribute => a }.get.name === "ts_alias",
"years() argument must be rewritten to ts_alias")
case other => fail(s"Expected TransformExpression at pos 1, got $other")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 17d00ec055e07..84eee883aeeda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -1191,11 +1191,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
- TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
+ TransformExpression(BucketFunction, Seq(Literal(numBuckets), expr))
}
def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = {
- TransformExpression(BucketFunction, expr, Some(numBuckets))
+ TransformExpression(BucketFunction, Seq(Literal(numBuckets)) ++ expr)
}
test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {