diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index ef1a14e50cdad..f05c0f7767949 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; /** * Base class for user-defined functions that can be 'reduced' on another function. @@ -60,6 +61,37 @@ @Evolving public interface ReducibleFunction { + /** + * Generic reducer for parameterized functions (bucket, truncate, etc.). + * + * If this function is 'reducible' on another function, return the {@link Reducer}. + *

+ * Each parameter is a scalar {@link Literal} carrying both its value and data type. Parameters + * are always scalar literals (e.g. bucket numBuckets, truncate width); complex types (array, map, + * struct) are not passed here. {@link Literal#value()} is Spark's internal representation (e.g. + * {@code UTF8String} for strings, {@code Decimal} for decimals); use {@link Literal#dataType()} + * to interpret it. + *

+ * Examples: + *

+ * + * @param thisParams literal parameters for this function + * @param otherFunction the other parameterized function + * @param otherParams literal parameters for the other function + * @return a reduction function if reducible, null otherwise + * @since 4.3.0 + */ + default Reducer reducer( + Literal[] thisParams, + ReducibleFunction otherFunction, + Literal[] otherParams) { + throw new UnsupportedOperationException(); + } + /** * This method is for the bucket function. * @@ -78,7 +110,12 @@ public interface ReducibleFunction { * @param otherBucketFunction the other parameterized function * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not + * @deprecated as of 4.3.0. Please override + * {@link #reducer(Literal[], ReducibleFunction, Literal[])} instead. + * The new overload supports transforms with any number of parameters of any type + * (e.g. truncate width, multi-arg range buckets), not just a single int. */ + @Deprecated(since = "4.3.0") default Reducer reducer( int thisNumBuckets, ReducibleFunction otherBucketFunction, @@ -101,6 +138,6 @@ default Reducer reducer( * @return a reduction function if it is reducible, null if not. */ default Reducer reducer(ReducibleFunction otherFunction) { - throw new UnsupportedOperationException(); + return reducer(new Literal[0], otherFunction, new Literal[0]); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 9041ed15fc501..a22aadb22786c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,45 +17,79 @@ package org.apache.spark.sql.catalyst.expressions +import scala.annotation.tailrec +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.FUNCTION_NAME import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.connector.expressions.{Literal => V2Literal, LiteralValue} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{AtomicType, DataType, IntegerType} /** * Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc. * * @param function the transform function itself. Spark will use it to decide whether two * partition transform expressions are compatible. - * @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise. */ -case class TransformExpression( - function: BoundFunction, - children: Seq[Expression], - numBucketsOpt: Option[Int] = None) extends Expression { +case class TransformExpression(function: BoundFunction, children: Seq[Expression]) + extends Expression with Logging { override def nullable: Boolean = true /** - * Whether this [[TransformExpression]] has the same semantics as `other`. - * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or - * `year(c)`. + * Extract literal children (constant parameters) from this transform. These are constant + * arguments like width in truncate(col, width). Literals are compared when checking if two + * transforms are the same. + */ + private lazy val literalChildren: Seq[Literal] = + children.collect { case l: Literal => l } + + /** + * Whether this [[TransformExpression]] has the same semantics as `other`. For instance, + * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`. + * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but may not to `truncate(c, 4)`. * * This will be used, for instance, by Spark to determine whether storage-partitioned join can * be triggered, by comparing partition transforms from both sides of the join and checking * whether they are compatible. * - * @param other the transform expression to compare to - * @return true if this and `other` has the same semantics w.r.t to transform, false otherwise. + * Two transforms are considered the same when they have the same function name, the same arity, + * and each pair of corresponding children matches: + * - literal arguments must be equal (e.g. numBuckets for bucket, width for truncate), so that + * `bucket(32, c)` is not the same as `bucket(16, c)`; + * - nested transform arguments must recursively be the same function, so that + * `bucket(4, years(c))` is not the same as `bucket(4, days(c))`; + * - everything else must be a plain column reference on both sides. Column identity is + * intentionally ignored (it is reconciled separately via positional matching), but a + * non-reference slot such as `c + 1` or `cast(c)`, or a literal/transform-vs-reference + * mismatch, is treated as not the same. + * + * @param other + * the transform expression to compare to + * @return + * true if this and `other` has the same semantics w.r.t to transform, false otherwise. */ - def isSameFunction(other: TransformExpression): Boolean = other match { - case TransformExpression(otherFunction, _, otherNumBucketsOpt) => - function.canonicalName() == otherFunction.canonicalName() && - numBucketsOpt == otherNumBucketsOpt - case _ => - false - } + def isSameFunction(other: TransformExpression): Boolean = + function.canonicalName() == other.function.canonicalName() && + childrenMatch(other)(_ == _) + + /** + * Per-position match of children, requiring equal arity. Literal slots are compared by the + * caller-supplied `literalsMatch`; nested transform slots must recursively be the same function; + * any other slot must be a plain column reference on both sides. + */ + private def childrenMatch(other: TransformExpression) + (literalsMatch: (Literal, Literal) => Boolean): Boolean = + children.length == other.children.length && + children.zip(other.children).forall { + case (l1: Literal, l2: Literal) => literalsMatch(l1, l2) + case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2) + case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2) + } /** * Whether this [[TransformExpression]]'s function is compatible with the `other` @@ -73,8 +107,8 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) - val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + val thisReducer = reducer(f, this, o, other) + val otherReducer = reducer(o, other, f, this) thisReducer.isDefined || otherReducer.isDefined case _ => false } @@ -92,24 +126,94 @@ case class TransformExpression( */ def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { - case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) + case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + reducer(e1, this, e2, other) case _ => None } } - // Return a Reducer for a reducible function on another reducible function + /** + * Extract all literal parameters of this transform as V2 [[V2Literal]]s, preserving each value's + * internal representation and its `DataType`. Connectors interpret the value via the accompanying + * `DataType` rather than relying on a pre-converted JVM type. + * + * Examples: + * bucket(4, col) => [Literal(4, IntegerType)] + * truncate(col, 3) => [Literal(3, IntegerType)] + * days(col) => [] (no literals) + */ + private def extractParameters: Array[V2Literal[_]] = + literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray + + /** + * Reducer precondition: same argument layout/structure as `other` (arity, aligned slots, equal + * nested transforms, column refs elsewhere). Only literal *values* may differ. Unlike + * [[isSameFunction]] the function name is not compared. + */ + private def sameArgumentLayout(other: TransformExpression): Boolean = + childrenMatch(other)((_, _) => true) + + /** + * Whether every literal parameter is a scalar (an [[AtomicType]]). Reducer parameters are scalar + * literals; this never forwards a complex Catalyst container (ArrayData / MapData / InternalRow) + * across the public reducer boundary -- such a transform is simply treated as not reducible. + */ + private def scalarLiteralParams: Boolean = + literalChildren.forall(_.dataType.isInstanceOf[AtomicType]) + + /** + * Return a Reducer for a reducible function on another reducible function + * Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions. + */ private def reducer( thisFunction: ReducibleFunction[_, _], - thisNumBucketsOpt: Option[Int], + thisExpr: TransformExpression, otherFunction: ReducibleFunction[_, _], - otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { - val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { - case (Some(numBuckets), Some(otherNumBuckets)) => - thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) - case _ => thisFunction.reducer(otherFunction) + otherExpr: TransformExpression): Option[Reducer[_, _]] = { + if (!thisExpr.sameArgumentLayout(otherExpr) || + !thisExpr.scalarLiteralParams || !otherExpr.scalarLiteralParams) { + return None + } + + val thisParams = thisExpr.extractParameters + val otherParams = otherExpr.extractParameters + val thisName = thisExpr.function.canonicalName() + + // Gate on DataType, not the boxed runtime class (DateType/YearMonthInterval box to Int). + def isSingleInt(p: Array[V2Literal[_]]): Boolean = { + p.length == 1 && p(0).dataType == IntegerType + } + + // Run a reducer overload; a thrown exception or a null both become None. warnOnUoe logs a hint + // when the function implements no usable reducer overload. + def attempt[R](call: => R, warnOnUoe: Boolean): Option[R] = { + val t = Try(Option(call)) + if (warnOnUoe) { + t.failed.foreach { + case _: UnsupportedOperationException => + logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " + + log"UnsupportedOperationException; treating as not reducible. Override " + + log"reducer(Literal[], ReducibleFunction, Literal[]) to enable SPJ.") + case _ => + } + } + t.toOption.flatten + } + + if (thisParams.isEmpty && otherParams.isEmpty) { + attempt(thisFunction.reducer(otherFunction), warnOnUoe = true) + } else if (isSingleInt(thisParams) && isSingleInt(otherParams)) { + // Try the deprecated int API first (legacy connectors); fall back to the generalized overload + // when it is absent or returns null. Option.orElse fires on None, covering both. + attempt(thisFunction.reducer( + thisParams(0).value().asInstanceOf[Int], otherFunction, + otherParams(0).value().asInstanceOf[Int]), warnOnUoe = false) + .orElse( + attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true)) + } else { + // Parameterized functions (bucket, truncate, etc.) + attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true) } - Option(res) } override def dataType: DataType = function.resultType() @@ -118,10 +222,7 @@ case class TransformExpression( copy(children = newChildren) private lazy val resolvedFunction: Option[Expression] = this match { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, - Seq(Literal(numBuckets)) ++ arguments)) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)) case _ => None } @@ -136,3 +237,18 @@ case class TransformExpression( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } + +object TransformExpression { + /** + * Whether `e` is a bare column reference: an [[Attribute]] or a [[GetStructField]] chain + * (struct-field access on a column). Shared by [[TransformExpression.isSameFunction]] and by + * `KeyedPartitioning.supportsExpressions`, which both decide whether a transform's single + * non-literal argument is a plain column. + */ + @tailrec + private[sql] def isColumnRef(e: Expression): Boolean = e match { + case _: Attribute => true + case g: GetStructField => isColumnRef(g.child) + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index c561677ed5ad5..5967491cf8e07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME -import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue} import org.apache.spark.sql.connector.read.{SampleMethod => V2SampleMethod} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -116,17 +116,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match { case IdentityTransform(ref) => Some(resolveRef[NamedExpression](ref, query)) - case BucketTransform(numBuckets, refs, sorted) - if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) => - val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query)) - // Create a dummy reference for `numBuckets` here and use that, together with `refs`, to - // look up the V2 function. - val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)() - funCatalogOpt.flatMap { catalog => - loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound => - TransformExpression(bound, resolvedRefs, Some(numBuckets)) - } - } case NamedTransform(name, args) => val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt)) funCatalogOpt.flatMap { catalog => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index aeacdaec7a8de..fc4d5bc5da73f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.physical -import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.{SparkException, SparkUnsupportedOperationException} @@ -641,19 +640,15 @@ object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { - transform.children.size == 1 && isReference(transform.children.head) - } - - @tailrec - def isReference(e: Expression): Boolean = e match { - case _: Attribute => true - case g: GetStructField => isReference(g.child) - case _ => false + // Should only consider column references, not literals. + val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal]) + // We need exactly one column reference per transform. + nonLiteralChildren.size == 1 && TransformExpression.isColumnRef(nonLiteralChildren.head) } expressions.forall { case t: TransformExpression if isSupportedTransform(t) => true - case e: Expression if isReference(e) => true + case e: Expression if TransformExpression.isColumnRef(e) => true case _ => false } } @@ -1335,7 +1330,13 @@ case class KeyedShuffleSpec( val newExpressions = partitioning.expressions.zip(keyPositions).map { case (te: TransformExpression, positionSet) => - te.copy(children = te.children.map(_ => clustering(positionSet.head))) + // Preserve literal parameters (e.g., numBuckets, truncate width) + // while replacing only column references with the new clustering expression + val newChildren = te.children.map { + case l: Literal => l // Keep literals as-is + case _ => clustering(positionSet.head) // Replace column references + } + te.copy(children = newChildren) case (_, positionSet) => clustering(positionSet.head) } KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala index 02e19dd053f29..9f3c4daa7a6f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, TransformExpression, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -96,9 +96,7 @@ object DistributionAndOrderingUtils { } private def resolveTransformExpression(expr: Expression): Expression = expr.transform { - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => - V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments) - case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + case TransformExpression(scalarFunc: ScalarFunction[_], arguments) => V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 711f6dbdcdb11..7dc5749c537c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.functions.{col, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with ExplainSuiteHelper { private val functions = Seq( @@ -126,7 +127,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } @@ -138,7 +139,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) // Has exactly one partition. val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) @@ -194,13 +195,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$table") val distribution = physical.ClusteredDistribution( - Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) + Seq(TransformExpression(BucketFunction, Seq(Literal(32), attr("ts"))))) checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) } } - test("non-clustered distribution: V2 function with multiple args") { + test("clustered distribution: V2 function with multiple args") { val partitions: Array[Transform] = Array( Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2)) ) @@ -216,7 +217,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val distribution = physical.ClusteredDistribution( Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2))))) - checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + // With truncate transform support, KeyedPartitioning should now work + val partitionKeys = Seq("aa", "bb", "cc").map(v => + InternalRow(UTF8String.fromString(v))) + checkQueryPlan(df, distribution, + physical.KeyedPartitioning(distribution.clustering, partitionKeys)) } /** @@ -4183,4 +4188,345 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + test("SPARK-50593: cross-function truncate vs bucket should NOT trigger SPJ") { + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3)) + ) + val partitions2 = Array( + Expressions.bucket(4, "data") + ) + + createTable("trunc_cross1", columns, partitions1) + sql("INSERT INTO testcat.ns.trunc_cross1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp))") + + createTable("trunc_cross2", columns2, partitions2) + sql("INSERT INTO testcat.ns.trunc_cross2 VALUES " + + "(1, 5, 'aaa'), " + + "(5, 10, 'bbb')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + s""" + |${selectWithMergeJoinHint("trunc_cross1", "trunc_cross2")} + |trunc_cross1.id, trunc_cross2.store_id + |FROM testcat.ns.trunc_cross1 JOIN testcat.ns.trunc_cross2 + |ON trunc_cross1.data = trunc_cross2.data + |ORDER BY trunc_cross1.id + |""".stripMargin) + + // Different functions (truncate vs bucket) are not mutually reducible, so a shuffle + // must still be planned. + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, + "truncate vs bucket are not compatible - a shuffle should be present, " + + "but none was planned") + checkAnswer(df, Seq(Row(0, 1), Row(1, 5))) + } + } + + test("SPARK-50593: truncate(3) vs truncate(5) triggers SPJ via width reducer") { + // Exercises the Literal[]-based reducer path end-to-end: truncate widths 3 and 5 + // are mutually reducible (reduce the larger to the smaller), so SPJ must avoid the shuffle. + val table1 = "trunc_three" + val table2 = "trunc_five" + + val partitions1 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(3))) + val partitions2 = Array( + Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(5))) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(1, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(2, 'orange', CAST('2020-01-01' AS timestamp))") + + createTable(table2, columns, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(10, 'apple', CAST('2022-01-01' AS timestamp)), " + + "(20, 'grape', CAST('2021-01-01' AS timestamp)), " + + "(30, 'orange', CAST('2020-01-01' AS timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id AS left_id, $table2.id AS right_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.data = $table2.data + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "truncate(3) vs truncate(5) should avoid shuffle via the width reducer, " + + "but a shuffle was planned") + checkAnswer(df, Seq(Row(0, 10), Row(1, 20), Row(2, 30))) + } + } + + test("SPARK-50593: existing bucket SPJ still works with Literal[] API") { + // Exercises the new Literal[]-based reducer path end-to-end: bucket(4) and + // bucket(2) differ, so SPJ can only avoid the shuffle if BucketFunction's reducer + // (now implemented via Literal[] params) correctly returns a GCD-based Reducer. + // BucketFunction overrides only the new API, so this also covers the deprecated->new + // fallback: the single-int dispatch tries reducer(int, ...) first (UOE), then the Literal[]. + val table1 = "bucket_compat1" + val table2 = "bucket_compat2" + + val partitions1 = Array(Expressions.bucket(4, "id")) + val partitions2 = Array(Expressions.bucket(2, "store_id")) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + "(2, 'ccc', CAST('2020-01-01' AS timestamp)), " + + "(3, 'ddd', CAST('2019-01-01' AS timestamp))") + + createTable(table2, columns2, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 5, 'aaa'), " + + "(1, 10, 'bbb'), " + + "(2, 15, 'ccc'), " + + "(3, 20, 'ddd')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id, $table2.store_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.id = $table2.store_id + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "bucket(4) vs bucket(2) should avoid shuffle via the GCD reducer, " + + "but a shuffle was planned") + checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3))) + } + } + + test("SPARK-50593: bucket(4) vs bucket(3) - no common divisor, must shuffle") { + // GCD(4, 3) = 1 -- BucketFunction.reducer returns null. Spark must NOT enable SPJ. + // Regression guard: a buggy null-handling in TransformExpression.reducer (e.g., + // Try(...).toOption instead of Try(Option(...))) would treat null as Some(null), + // enable SPJ, and produce wrong join results for incompatible bucket layouts. + val table1 = "bucket_gcd1_a" + val table2 = "bucket_gcd1_b" + + val partitions1 = Array(Expressions.bucket(4, "id")) + val partitions2 = Array(Expressions.bucket(3, "store_id")) + + createTable(table1, columns, partitions1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 'aaa', CAST('2022-01-01' AS timestamp)), " + + "(1, 'bbb', CAST('2021-01-01' AS timestamp)), " + + "(2, 'ccc', CAST('2020-01-01' AS timestamp))") + + createTable(table2, columns2, partitions2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 5, 'aaa'), " + + "(1, 10, 'bbb'), " + + "(2, 15, 'ccc')") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint(table1, table2)} + |$table1.id, $table2.store_id + |FROM testcat.ns.$table1 JOIN testcat.ns.$table2 + |ON $table1.id = $table2.store_id + |ORDER BY $table1.id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, + "bucket(4) vs bucket(3) have no common divisor (GCD=1), so the reducer " + + "returns null. SPJ must NOT be enabled; a shuffle is required.") + checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(2, 2))) + } + } + + test("SPARK-50593: isSameFunction recurses into nested transforms, respects column-ref slots") { + import org.apache.spark.sql.catalyst.expressions.{Add, Expression, GetStructField} + val a = attr("a") + val b = attr("b") + def bucket(n: Int, e: Expression): TransformExpression = + TransformExpression(BucketFunction, Seq(Literal(n), e)) + def years(e: Expression): TransformExpression = TransformExpression(YearsFunction, Seq(e)) + def days(e: Expression): TransformExpression = TransformExpression(DaysFunction, Seq(e)) + + // Nested identical -> same (recursing into the inner transform), with column identity ignored. + // isSameFunction stays correct for nested shapes even though the SPJ gate currently rejects + // them; keeping this behavior is the right shape for any future nested support. + assert(bucket(4, years(a)).isSameFunction(bucket(4, years(a)))) + assert(bucket(4, years(a)).isSameFunction(bucket(4, years(b))), "column identity is ignored") + // Nested different inner -> not same. + assert(!bucket(4, years(a)).isSameFunction(bucket(4, days(a)))) + // Different outer literal -> not same. + assert(!bucket(4, years(a)).isSameFunction(bucket(2, years(a)))) + // Flat sanity (no nesting). + assert(bucket(4, a).isSameFunction(bucket(4, b))) + assert(!bucket(4, a).isSameFunction(bucket(2, b))) + + // A non-reference column slot (a + 1) carries value-changing semantics, so it is conservatively + // treated as not-same -- even compared to itself. + val add = bucket(4, Add(a, Literal(1))) + assert(!add.isSameFunction(bucket(4, Add(b, Literal(1))))) + assert(!add.isSameFunction(add), "a non-reference slot is treated as not-same by design") + + // Struct-field column references are recognized (reflexivity preserved for genuine refs). + val s = AttributeReference("s", StructType(Seq(StructField("f", IntegerType))))() + val sf = GetStructField(s, 0) + assert(bucket(4, sf).isSameFunction(bucket(4, sf))) + } + + test("SPARK-50593: supportsExpressions admits flat parameterized transforms, " + + "rejects nested and non-reference slots") { + import org.apache.spark.sql.catalyst.expressions.{Add, Expression} + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + def bucket(n: Int, e: Expression): TransformExpression = + TransformExpression(BucketFunction, Seq(Literal(n), e)) + + // Flat parameterized transform over a bare column -> admitted (one non-literal child = column). + assert(physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, a)))) + // Bare identity column -> admitted. + assert(physical.KeyedPartitioning.supportsExpressions(Seq(a))) + + // Nested transform -> rejected: the non-literal child is a transform, not a column reference. + // SPJ reasons about a transform via its function and literal params alone, which is unsound + // when the remaining argument is itself a transform. + val nested = bucket(4, TransformExpression(YearsFunction, Seq(a))) + assert(!physical.KeyedPartitioning.supportsExpressions(Seq(nested))) + + // Value-changing slot (a + 1) -> rejected: not a plain column reference. + assert(!physical.KeyedPartitioning.supportsExpressions(Seq(bucket(4, Add(a, Literal(1)))))) + + // Two non-literal column references -> rejected: a partition expression must map to exactly one + // clustering column (the positional keyPositions model needs a single column per transform). + assert(!physical.KeyedPartitioning.supportsExpressions( + Seq(TransformExpression(BucketFunction, Seq(Literal(4), a, b))))) + } + + test("SPARK-50593: integer truncate is reducible via lcm (generalized reducer, non-bucket)") { + // A second reducible transform exercising the generalized Literal[] reducer API with reducer + // math distinct from bucket (GCD) and string truncate (prefix-min): integer truncate snaps to + // a coarser grid, so truncate(v, W1) and truncate(v, W2) reduce onto multiples of lcm(W1, W2). + import org.apache.spark.sql.catalyst.expressions.Expression + val id = attr("id") + def itrunc(e: Expression, w: Int): TransformExpression = + TransformExpression(IntegerTruncateFunction, Seq(e, Literal(w))) + + // Same width -> same function (no reduction needed). + assert(itrunc(id, 4).isSameFunction(itrunc(id, 4))) + + // W2 is a multiple of W1: the finer side (W1=2) reduces onto the coarser grid (W2=4). + assert(itrunc(id, 2).isCompatible(itrunc(id, 4))) + val r = itrunc(id, 2).reducers(itrunc(id, 4)) + assert(r.isDefined, "truncate(2) must reduce onto truncate(4)") + val red = r.get.asInstanceOf[Reducer[Integer, Integer]] + // truncate(.,2) values snapped to multiples of 4: 6 -> 4, 2 -> 0, 8 -> 8 + assert(red.reduce(6) == 4 && red.reduce(2) == 0 && red.reduce(8) == 8) + // The coarser side (4) is already the common grid -> no reducer. + assert(itrunc(id, 4).reducers(itrunc(id, 2)).isEmpty) + + // Neither divides the other: both sides reduce to the lcm grid. + assert(itrunc(id, 6).isCompatible(itrunc(id, 4))) // lcm(6, 4) = 12 + assert(itrunc(id, 6).reducers(itrunc(id, 4)).isDefined) + assert(itrunc(id, 4).reducers(itrunc(id, 6)).isDefined) + assert(itrunc(id, 3).isCompatible(itrunc(id, 5))) // coprime -> lcm(3, 5) = 15 + } + + test("SPARK-50593: deprecated int reducer API still works (legacy connector backward compat)") { + // The reducer dispatch attempts the deprecated reducer(int, func, int) first for single-int + // params, so a ReducibleFunction that overrides ONLY the deprecated method still reduces. + // This mirrors how Iceberg 1.10.0 (and earlier) ship -- they predate the Literal[] API. + val bucketExpr4 = TransformExpression(LegacyBucketFunction, Seq(Literal(4), attr("id"))) + val bucketExpr2 = TransformExpression(LegacyBucketFunction, Seq(Literal(2), attr("id"))) + + val reducer = bucketExpr4.reducers(bucketExpr2) + assert(reducer.isDefined, "Expected a reducer for legacy_bucket(4) on legacy_bucket(2)") + + // Verify the returned Reducer actually reduces bucket 4 -> bucket 2 (GCD = 2). + // bucket(4, x) produces values in [0, 4); reducing by GCD=2 gives v % 2. + val r = reducer.get.asInstanceOf[Reducer[Integer, Integer]] + assert(r.reduce(3) == 1, s"Expected reduce(3) == 1, got ${r.reduce(3)}") + assert(r.reduce(2) == 0, s"Expected reduce(2) == 0, got ${r.reduce(2)}") + } + + test("SPARK-50593: a non-IntegerType param (DateType) does not reach the deprecated " + + "int reducer") { + // DateType is stored as a boxed Integer (epoch days) internally, so the reducer dispatch must + // key off the DataType, not the runtime class -- otherwise a DateType param is mistaken for the + // bucket-style int param and routed to the deprecated reducer(int, ...). LegacyBucketFunction + // overrides ONLY that deprecated method, so with a DateType param it must be unreachable, + // leaving the pair not reducible (rather than producing a bogus GCD reducer over epoch-days). + val l = TransformExpression(LegacyBucketFunction, Seq(Literal(8, DateType), attr("id"))) + val r = TransformExpression(LegacyBucketFunction, Seq(Literal(4, DateType), attr("id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "a DateType param must not reach the deprecated int reducer") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + } + + test("SPARK-50593: mismatched column/literal argument layout is not reducible") { + // Both transforms pass the strict gate (one column-reference non-literal child), but the column + // and literal sit in swapped positions: truncate(id, 2) is (col, lit) while truncate(4, sid) is + // (lit, col). The reducer only sees the literal positions ([2] vs [4]), so without an + // argument-layout check it would wrongly reduce these and co-locate non-matching rows. + // IntegerTruncateFunction has two same-typed (Int) args, which makes this layout reachable. + val l = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2))) + val r = TransformExpression(IntegerTruncateFunction, Seq(Literal(4), attr("store_id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "swapped column/literal layout must not be reducible") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + + // Control: same layout (col, lit) on both sides remains reducible via lcm(2, 4). + val a = TransformExpression(IntegerTruncateFunction, Seq(attr("id"), Literal(2))) + val b = TransformExpression(IntegerTruncateFunction, Seq(attr("store_id"), Literal(4))) + assert(a.isCompatible(b), "aligned (col, lit) layout must remain reducible") + } + + test("SPARK-50593: deprecated reducer returning null falls back to the generalized overload") { + // DualApiBucketFunction implements both overloads: the deprecated reducer(int, ...) returns + // null, while the generalized reducer(Literal[], ...) returns a valid GCD reducer. The dispatch + // tries the deprecated one first; a null there must fall through to the generalized overload + // (Option.orElse fires on None), not be mistaken for "not reducible". + val l = TransformExpression(DualApiBucketFunction, Seq(Literal(4), attr("id"))) + val r = TransformExpression(DualApiBucketFunction, Seq(Literal(2), attr("store_id"))) + assert(l.isCompatible(r), "a null from the deprecated reducer must fall back to the new API") + val red = l.reducers(r) + assert(red.isDefined, "generalized reducer must be reached when deprecated returns null") + assert(red.get.asInstanceOf[Reducer[Integer, Integer]].reduce(3) == 1) + } + + test("SPARK-50593: a complex (non-scalar) literal param is not reducible") { + // Reducer parameters must be scalar literals. ArrayParamFunction's generalized reducer returns + // a reducer unconditionally, so reaching it at all is the leak; the scalar guard must refuse + // the ArrayType literal param first. Different array values keep isSameFunction false, forcing + // the reducer path where the guard applies. + val l = TransformExpression(ArrayParamFunction, + Seq(Literal.create(Array(1, 2, 3), ArrayType(IntegerType)), attr("id"))) + val r = TransformExpression(ArrayParamFunction, + Seq(Literal.create(Array(4, 5, 6), ArrayType(IntegerType)), attr("store_id"))) + assert(!l.isSameFunction(r)) + assert(!l.isCompatible(r), "a complex literal param must not be reducible") + assert(l.reducers(r).isEmpty && r.reducers(l).isEmpty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 35102c6893d3b..b24a2c731adaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -213,11 +214,14 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer( - thisNumBuckets: Int, + thisParams: Array[Literal[_]], otherFunc: ReducibleFunction[_, _], - otherNumBuckets: Int): Reducer[Int, Int] = { + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { + val thisNumBuckets = thisParams(0).value().asInstanceOf[Int] + val otherNumBuckets = otherParams(0).value().asInstanceOf[Int] + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) if (gcd > 1 && gcd != thisNumBuckets) { return BucketReducer(gcd) @@ -235,6 +239,95 @@ case class BucketReducer(divisor: Int) extends Reducer[Int, Int] { override def displayName(): String = toString } +/** + * A bucket function that only overrides the deprecated `reducer(int, func, int)` method, + * not the new `reducer(Literal[], func, Literal[])` method. + * + * Used to verify that the default implementation of the new method correctly falls back + * to the deprecated int-based API, so legacy implementations continue to work. + */ +object LegacyBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "legacy_bucket" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + Math.floorMod(input.getLong(1), input.getInt(0)) + } + + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { + if (otherFunc == LegacyBucketFunction) { + val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt + if (gcd > 1 && gcd != thisNumBuckets) { + return BucketReducer(gcd) + } + } + null + } +} + +/** + * A bucket function that implements BOTH reducer overloads: the deprecated `reducer(int, ..., int)` + * always returns null (not reducible via the old API), while the new `reducer(Literal[], ...)` + * returns a GCD-based reducer. Used to verify that the dispatch falls back to the generalized + * overload when the deprecated one returns null (not only when it throws). + */ +object DualApiBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "dual_bucket" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + Math.floorMod(input.getLong(1), input.getInt(0)) + } + + // Deprecated API: intentionally signals "not reducible" via null (not via an exception). + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = null + + // New API: a real GCD-based reducer. + override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { + if (otherFunc == DualApiBucketFunction) { + val thisNumBuckets = thisParams(0).value().asInstanceOf[Int] + val otherNumBuckets = otherParams(0).value().asInstanceOf[Int] + val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt + if (gcd > 1 && gcd != thisNumBuckets) { + return BucketReducer(gcd) + } + } + null + } +} + +/** + * A function with a complex (ArrayType) literal parameter. Its generalized reducer returns a valid + * reducer unconditionally, so a test can prove the dispatch refuses to invoke it for a non-scalar + * literal param (rather than the call happening to fail on a cast). + */ +object ArrayParamFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(ArrayType(IntegerType), LongType) + override def resultType(): DataType = IntegerType + override def name(): String = "array_param" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = input.getInt(1) + + override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[Int, Int] = BucketReducer(1) +} + object UnboundStringSelfFunction extends UnboundFunction { override def bind(inputType: StructType): BoundFunction = StringSelfFunction override def description(): String = name() @@ -253,12 +346,35 @@ object StringSelfFunction extends ScalarFunction[UTF8String] { } object UnboundTruncateFunction extends UnboundFunction { - override def bind(inputType: StructType): BoundFunction = TruncateFunction + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 2) { + inputType.head.dataType match { + case StringType => TruncateFunction + case IntegerType => IntegerTruncateFunction + case _ => + throw new UnsupportedOperationException( + s"'truncate' does not support data type: ${inputType.head.dataType}") + } + } else { + throw new UnsupportedOperationException( + "'truncate' requires exactly 2 arguments: (column, width)") + } + } + override def description(): String = name() override def name(): String = "truncate" } -object TruncateFunction extends ScalarFunction[UTF8String] { +/** + * Truncate transform for String type. + * Follows Iceberg spec: truncate(str, L) = str[0:L] + * + * Implements ReducibleFunction: ANY two different widths are compatible. + * The reducer uses the smaller width. + */ +object TruncateFunction + extends ScalarFunction[UTF8String] + with ReducibleFunction[UTF8String, UTF8String] { override def inputTypes(): Array[DataType] = Array(StringType, IntegerType) override def resultType(): DataType = StringType override def name(): String = "truncate" @@ -266,7 +382,84 @@ object TruncateFunction extends ScalarFunction[UTF8String] { override def toString: String = name() override def produceResult(input: InternalRow): UTF8String = { val str = input.getUTF8String(0) - val length = input.getInt(1) - str.substring(0, length) + val width = input.getInt(1) + str.substring(0, width) + } + + override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[UTF8String, UTF8String] = { + + if (otherFunc == TruncateFunction) { + val thisWidth = thisParams(0).value().asInstanceOf[Int] + val otherWidth = otherParams(0).value().asInstanceOf[Int] + val smallerWidth = math.min(thisWidth, otherWidth) + + if (smallerWidth != thisWidth) { + return TruncateReducer(smallerWidth) + } + } + null + } +} + +case class TruncateReducer(width: Int) extends Reducer[UTF8String, UTF8String] { + override def reduce(value: UTF8String): UTF8String = { + value.substring(0, width) + } + override def resultType(): DataType = StringType + override def displayName(): String = s"truncate($width)" +} + +/** + * Truncate transform for Integer type. + * Follows Iceberg spec: truncate(value, W) = value - (((value % W) + W) % W), which snaps `value` + * down to a multiple of `W`. + * + * Implements ReducibleFunction: truncate(v, W1) and truncate(v, W2) are always reducible onto a + * common coarser grid of multiples of lcm(W1, W2). The finer side (whose width does not already + * equal the lcm) reduces by snapping to that grid; when W2 is a multiple of W1 the lcm is simply + * the coarser width W2. + */ +object IntegerTruncateFunction + extends ScalarFunction[Int] + with ReducibleFunction[Int, Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "truncate" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): Int = { + val value = input.getInt(0) + val width = input.getInt(1) + value - (((value % width) + width) % width) } + + override def reducer( + thisParams: Array[Literal[_]], + otherFunc: ReducibleFunction[_, _], + otherParams: Array[Literal[_]]): Reducer[Int, Int] = { + if (otherFunc == IntegerTruncateFunction) { + val thisWidth = thisParams(0).value().asInstanceOf[Int] + val otherWidth = otherParams(0).value().asInstanceOf[Int] + val common = lcm(thisWidth, otherWidth) + // Only the finer side reduces; if `common == thisWidth` this side is already the common grid. + if (common != thisWidth) { + return IntTruncateReducer(common) + } + } + null + } + + private def lcm(a: Int, b: Int): Int = { + val g = BigInt(a).gcd(BigInt(b)) + (BigInt(a) / g * BigInt(b)).toInt + } +} + +case class IntTruncateReducer(width: Int) extends Reducer[Int, Int] { + override def reduce(value: Int): Int = value - (((value % width) + width) % width) + override def resultType(): DataType = IntegerType + override def displayName(): String = s"truncate($width)" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index a70baece77844..629d65bb20c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, TransformExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, TransformExpression} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, YearsFunction} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -492,7 +492,7 @@ class ProjectedOrderingAndPartitioningSuite // KP([bucket(32, id)], keys1d) through Project(id as pk) should produce // KP([bucket(32, pk)], keys1d): the alias is pushed into the bucket's column argument. val id = AttributeReference("id", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val keys1d = Seq(InternalRow(0), InternalRow(1), InternalRow(2)) val child = DummyLeafExecWithPartitioning( output = Seq(id), @@ -507,7 +507,7 @@ class ProjectedOrderingAndPartitioningSuite case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket function and numBuckets must be preserved after alias substitution") - assert(te.children.head.asInstanceOf[Attribute].name === "pk", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "pk", "bucket's column argument must be rewritten to the aliased attribute") case other => fail(s"Expected TransformExpression, got $other") } @@ -524,7 +524,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id)], keys1d, isNarrowed=true, isGrouped=false). val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) // Projected to position [0] (bucket): (0),(1),(0) -- bucket value 0 appears twice. val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) @@ -539,7 +539,7 @@ class ProjectedOrderingAndPartitioningSuite kp.expressions.head match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr), "bucket must be the surviving expression") - assert(te.children.head.asInstanceOf[Attribute].name === "id") + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id") case other => fail(s"Expected TransformExpression, got $other") } assert(kp.isNarrowed, "dropping years(ts) position must mark the KP as narrowed") @@ -554,7 +554,7 @@ class ProjectedOrderingAndPartitioningSuite // Result: KP([bucket(32, id), years(ts_alias)], keys2d) -- not narrowed. val id = AttributeReference("id", IntegerType)() val ts = AttributeReference("ts", IntegerType)() - val bucketExpr = TransformExpression(BucketFunction, Seq(id), Some(32)) + val bucketExpr = TransformExpression(BucketFunction, Seq(Literal(32), id)) val yearsExpr = TransformExpression(YearsFunction, Seq(ts)) val keys2d = Seq(InternalRow(0, 2020), InternalRow(1, 2020), InternalRow(0, 2021)) val child = DummyLeafExecWithPartitioning( @@ -569,14 +569,14 @@ class ProjectedOrderingAndPartitioningSuite kp.expressions(0) match { case te: TransformExpression => assert(te.isSameFunction(bucketExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "id", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "id", "bucket's argument must remain id (no alias for id in this projection)") case other => fail(s"Expected TransformExpression at pos 0, got $other") } kp.expressions(1) match { case te: TransformExpression => assert(te.isSameFunction(yearsExpr)) - assert(te.children.head.asInstanceOf[Attribute].name === "ts_alias", + assert(te.children.collectFirst { case a: Attribute => a }.get.name === "ts_alias", "years() argument must be rewritten to ts_alias") case other => fail(s"Expected TransformExpression at pos 1, got $other") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 17d00ec055e07..84eee883aeeda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1191,11 +1191,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { } def bucket(numBuckets: Int, expr: Expression): TransformExpression = { - TransformExpression(BucketFunction, Seq(expr), Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets), expr)) } def buckets(numBuckets: Int, expr: Seq[Expression]): TransformExpression = { - TransformExpression(BucketFunction, expr, Some(numBuckets)) + TransformExpression(BucketFunction, Seq(Literal(numBuckets)) ++ expr) } test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {