Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Literal;

/**
* Base class for user-defined functions that can be 'reduced' on another function.
@@ -60,6 +61,37 @@
@Evolving
public interface ReducibleFunction<I, O> {

/**
* Generic reducer for parameterized functions (bucket, truncate, etc.).
*
* If this function is 'reducible' on another function, return the {@link Reducer}.
* <p>
* Each parameter is a scalar {@link Literal} carrying both its value and data type. Parameters
* are always scalar literals (e.g. bucket numBuckets, truncate width); complex types (array, map,
* struct) are not passed here. {@link Literal#value()} is Spark's internal representation (e.g.
* {@code UTF8String} for strings, {@code Decimal} for decimals); use {@link Literal#dataType()}
* to interpret it.
* <p>
* Examples:
* <ul>
* <li>bucket(4, x) and bucket(2, x): thisParams = [4], otherParams = [2]</li>
* <li>truncate(x, 3) and truncate(x, 5): thisParams = [3], otherParams = [5]</li>
* <li>hypothetical range_bucket(x, 0L, 100L, 4): thisParams = [0L, 100L, 4]</li>
* </ul>
*
* @param thisParams literal parameters for this function
* @param otherFunction the other parameterized function
* @param otherParams literal parameters for the other function
* @return a reduction function if reducible, null otherwise
* @since 4.3.0
*/
default Reducer<I, O> reducer(
Literal<?>[] thisParams,
ReducibleFunction<?, ?> otherFunction,
Literal<?>[] otherParams) {
throw new UnsupportedOperationException();
}

/**
* This method is for the bucket function.
*
@@ -78,7 +110,12 @@ public interface ReducibleFunction<I, O> {
* @param otherBucketFunction the other parameterized function
* @param otherNumBuckets parameter for the other function
* @return a reduction function if it is reducible, null if not
* @deprecated as of 4.3.0. Please override
* {@link #reducer(Literal[], ReducibleFunction, Literal[])} instead.
* The new overload supports transforms with any number of parameters of any type
* (e.g. truncate width, multi-arg range buckets), not just a single int.
*/
@Deprecated(since = "4.3.0")
default Reducer<I, O> reducer(
int thisNumBuckets,
ReducibleFunction<?, ?> otherBucketFunction,
@@ -101,6 +138,6 @@ default Reducer<I, O> reducer(
* @return a reduction function if it is reducible, null if not.
*/
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
throw new UnsupportedOperationException();
return reducer(new Literal<?>[0], otherFunction, new Literal<?>[0]);
}
}
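
To make the Javadoc examples above (`thisParams = [4]`, `otherParams = [2]`) concrete, here is a minimal sketch of how a connector-side bucket function might implement the generalized overload. `BucketSketch`, `Param`, and `intParam` are hypothetical stand-ins defined locally so the snippet compiles without Spark; the real API receives `Literal<?>[]` and returns a `Reducer<I, O>`. The gcd rule is an assumption chosen for illustration, not something this PR prescribes.

```java
import java.util.function.Function;

public class BucketSketch {
  // Hypothetical stand-in for a V2 Literal: a value plus the name of its data type.
  static final class Param {
    final Object value;
    final String typeName;
    Param(Object value, String typeName) {
      this.value = value;
      this.typeName = typeName;
    }
  }

  static Param intParam(int n) {
    return new Param(n, "int");
  }

  static int gcd(int a, int b) {
    return b == 0 ? a : gcd(b, a % b);
  }

  // Mirrors the shape of reducer(Literal[], ReducibleFunction, Literal[]):
  // returns a reduction function, or null when this side is not reducible.
  static Function<Integer, Integer> reducer(Param[] thisParams, Param[] otherParams) {
    if (thisParams.length != 1 || otherParams.length != 1) {
      return null;
    }
    // Interpret the value through its declared type, not its boxed runtime class.
    if (!"int".equals(thisParams[0].typeName) || !"int".equals(otherParams[0].typeName)) {
      return null;
    }
    int thisN = (Integer) thisParams[0].value;
    int otherN = (Integer) otherParams[0].value;
    int common = gcd(thisN, otherN);
    // Assumed rule: reduce this side only when it has more buckets than the common divisor.
    if (common > 1 && common != thisN) {
      return bucket -> bucket % common;
    }
    return null;
  }

  public static void main(String[] args) {
    Param[] four = { intParam(4) };
    Param[] two = { intParam(2) };
    System.out.println(reducer(four, two).apply(3)); // prints 1
    System.out.println(reducer(two, four));          // prints null
  }
}
```

In this sketch only the side with more buckets returns a reducer; the other side returns null and is left unchanged.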
Original file line number Diff line number Diff line change
@@ -17,45 +17,79 @@

package org.apache.spark.sql.catalyst.expressions

import scala.annotation.tailrec
import scala.util.Try

import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.FUNCTION_NAME
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
import org.apache.spark.sql.connector.expressions.{Literal => V2Literal, LiteralValue}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{AtomicType, DataType, IntegerType}

/**
* Represents a partition transform expression, for instance, `bucket`, `days`, `years`, etc.
*
* @param function the transform function itself. Spark will use it to decide whether two
* partition transform expressions are compatible.
* @param numBucketsOpt the number of buckets if the transform is `bucket`. Unset otherwise.
*/
case class TransformExpression(
function: BoundFunction,
children: Seq[Expression],
numBucketsOpt: Option[Int] = None) extends Expression {
case class TransformExpression(function: BoundFunction, children: Seq[Expression])
extends Expression with Logging {

override def nullable: Boolean = true

/**
* Whether this [[TransformExpression]] has the same semantics as `other`.
* For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or
* `year(c)`.
* Extract literal children (constant parameters) from this transform. These are constant
* arguments like width in truncate(col, width). Literals are compared when checking if two
* transforms are the same.
*/
private lazy val literalChildren: Seq[Literal] =
children.collect { case l: Literal => l }

/**
* Whether this [[TransformExpression]] has the same semantics as `other`. For instance,
* `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)` or `year(c)`.
* Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but not to `truncate(c, 4)`.
*
* This will be used, for instance, by Spark to determine whether storage-partitioned join can
* be triggered, by comparing partition transforms from both sides of the join and checking
* whether they are compatible.
*
* @param other the transform expression to compare to
* @return true if this and `other` has the same semantics w.r.t to transform, false otherwise.
* Two transforms are considered the same when they have the same function name, the same arity,
* and each pair of corresponding children matches:
* - literal arguments must be equal (e.g. numBuckets for bucket, width for truncate), so that
* `bucket(32, c)` is not the same as `bucket(16, c)`;
* - nested transform arguments must recursively be the same function, so that
* `bucket(4, years(c))` is not the same as `bucket(4, days(c))`;
* - everything else must be a plain column reference on both sides. Column identity is
* intentionally ignored (it is reconciled separately via positional matching), but a
* non-reference slot such as `c + 1` or `cast(c)`, or a literal/transform-vs-reference
* mismatch, is treated as not the same.
*
* @param other
* the transform expression to compare to
* @return
* true if this and `other` have the same semantics with respect to the transform, false otherwise.
*/
def isSameFunction(other: TransformExpression): Boolean = other match {
case TransformExpression(otherFunction, _, otherNumBucketsOpt) =>
function.canonicalName() == otherFunction.canonicalName() &&
numBucketsOpt == otherNumBucketsOpt
case _ =>
false
}
def isSameFunction(other: TransformExpression): Boolean =
function.canonicalName() == other.function.canonicalName() &&
childrenMatch(other)(_ == _)

/**
* Per-position match of children, requiring equal arity. Literal slots are compared by the
* caller-supplied `literalsMatch`; nested transform slots must recursively be the same function;
* any other slot must be a plain column reference on both sides.
*/
private def childrenMatch(other: TransformExpression)
(literalsMatch: (Literal, Literal) => Boolean): Boolean =
children.length == other.children.length &&
children.zip(other.children).forall {
case (l1: Literal, l2: Literal) => literalsMatch(l1, l2)
case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2)
case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2)
}

/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
@@ -73,8 +107,8 @@ case class TransformExpression(
} else {
(function, other.function) match {
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
val thisReducer = reducer(f, this, o, other)
val otherReducer = reducer(o, other, f, this)
thisReducer.isDefined || otherReducer.isDefined
case _ => false
}
@@ -92,24 +126,94 @@ case class TransformExpression(
*/
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
(function, other.function) match {
case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
reducer(e1, this, e2, other)
case _ => None
}
}

// Return a Reducer for a reducible function on another reducible function
/**
* Extract all literal parameters of this transform as V2 [[V2Literal]]s, preserving each value's
* internal representation and its `DataType`. Connectors interpret the value via the accompanying
* `DataType` rather than relying on a pre-converted JVM type.
*
* Examples:
* bucket(4, col) => [Literal(4, IntegerType)]
* truncate(col, 3) => [Literal(3, IntegerType)]
* days(col) => [] (no literals)
*/
private def extractParameters: Array[V2Literal[_]] =
literalChildren.map(l => LiteralValue(l.value, l.dataType): V2Literal[_]).toArray

/**
* Reducer precondition: same argument layout/structure as `other` (arity, aligned slots, equal
* nested transforms, column refs elsewhere). Only literal *values* may differ. Unlike
* [[isSameFunction]] the function name is not compared.
*/
private def sameArgumentLayout(other: TransformExpression): Boolean =

Contributor

sameArgumentLayout and isSameFunction (L76-86) are now two separate children.zip(...) walkers. The literal axis differing is deliberate — equality for "same transform" vs. "values may differ, a Reducer reconciles them". But the non-literal handling drifted apart, and sameArgumentLayout ended up strictly weaker:

| position pair | isSameFunction | sameArgumentLayout |
| --- | --- | --- |
| (literal, literal) | l1 == l2 | true (intended) |
| (transform, transform) | recursive isSameFunction | case _ => true |
| (col, transform) | isColumnRef && isColumnRef | case _ => true |

So sameArgumentLayout treats any two non-literals as interchangeable, and its correctness rests entirely on the supportsExpressions gate guaranteeing the single non-literal child is always a plain column ref. On its own it would let two structurally different transforms that share a function name and literal layout reduce (e.g. bucket(4, days(c)) vs bucket(4, hours(c))).

That's fine as the code stands. But since these were a single parameterized helper a commit ago (childrenMatch), it'd be nice to bring sameArgumentLayout back in sync with isSameFunction so the precondition is correct on its own rather than gate-dependent (future-proof):

/** Per-position match; `literalsMatch` is the only axis that varies between callers. */
private def childrenMatch(other: TransformExpression)
    (literalsMatch: (Literal, Literal) => Boolean): Boolean =
  children.length == other.children.length &&
    children.zip(other.children).forall {
      case (l1: Literal, l2: Literal) => literalsMatch(l1, l2)
      case (t1: TransformExpression, t2: TransformExpression) => t1.isSameFunction(t2)
      case (c1, c2) => TransformExpression.isColumnRef(c1) && TransformExpression.isColumnRef(c2)
    }

def isSameFunction(other: TransformExpression): Boolean =
  function.canonicalName() == other.function.canonicalName() &&
    childrenMatch(other)(_ == _)

// reducer precondition: same layout/structure, literal *values* may differ
private def sameArgumentLayout(other: TransformExpression): Boolean =
  childrenMatch(other)((_, _) => true)

Equally fine to leave as-is if you'd rather not touch it now — mainly flagging so the gate-dependency is a conscious choice.

Author

@peter-toth I like the idea. Yeah, it's very similar to the parameterized helper I created for nested transforms. I'll make the change.
Thanks.

Member

[P2] This equal-arity precondition prevents the generalized reducer from handling valid mixed-arity transform pairs. For example, both raw(x) and mod(x, 2) pass KeyedPartitioning.supportsExpressions because each has one direct column reference, and a connector can validly implement a reducer for [] versus [2]; however, childrenMatch rejects their [column] versus [column, literal] children before either reducer is called. This reintroduces the parameterized-vs-zero-parameter gap discussed earlier, where we agreed to leave compatibility to the connector, and forces an unnecessary shuffle. Please preserve the positional slot-safety check without globally requiring equal child counts, and add a zero-vs-one-parameter regression test.

Contributor

That is a good point.

@metanil, please disregard my previous request regarding sameArgumentLayout().

childrenMatch(other)((_, _) => true)

/**
* Whether every literal parameter is a scalar (an [[AtomicType]]). Reducer parameters are scalar
* literals; this never forwards a complex Catalyst container (ArrayData / MapData / InternalRow)
* across the public reducer boundary -- such a transform is simply treated as not reducible.
*/
private def scalarLiteralParams: Boolean =

Member

nit: the name is not clear. literalParamsAreScalar, or allLiteralParamsAreScalar?

literalChildren.forall(_.dataType.isInstanceOf[AtomicType])

Member

[P2] AtomicType is stricter than the documented non-complex scalar contract. CalendarIntervalType is explicitly non-complex in Spark, although it does not extend AtomicType, and Expressions.literal(new CalendarInterval(...)) creates a valid self-describing V2 literal. A connector whose transform returns a comparable key type can therefore pass KeyedPartitioning.supportsExpressions, yet differing interval parameters are rejected here before its generalized reducer can reconcile them, forcing an unnecessary shuffle. The public reducer Javadoc excludes array/map/struct parameters and says the new overload supports parameters of any type. Please reject the specified container types rather than all non-AtomicType values, and add a positive CalendarIntervalType reducer test.
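
The difference between "allow only atomic types" and "reject only container types" can be sketched as follows. This is a minimal illustration, not the PR's code: `ScalarParamGate` and its `Kind` enum are hypothetical stand-ins for Spark's `DataType` hierarchy, and the set of atomic kinds is deliberately abbreviated.

```java
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

public class ScalarParamGate {
  // Hypothetical stand-in for a handful of Spark data types.
  enum Kind { INTEGER, STRING, DECIMAL, CALENDAR_INTERVAL, ARRAY, MAP, STRUCT }

  // Abbreviated allow-list: CALENDAR_INTERVAL is absent, mirroring the fact
  // that CalendarIntervalType does not extend AtomicType.
  static final Set<Kind> ATOMIC = EnumSet.of(Kind.INTEGER, Kind.STRING, Kind.DECIMAL);

  // The container types the reducer Javadoc excludes.
  static final Set<Kind> CONTAINERS = EnumSet.of(Kind.ARRAY, Kind.MAP, Kind.STRUCT);

  // Allow-list gate: anything not known to be atomic is rejected.
  static boolean allAtomic(List<Kind> paramTypes) {
    return ATOMIC.containsAll(paramTypes);
  }

  // Deny-list gate: only the documented container types are rejected.
  static boolean allLiteralParamsAreScalar(List<Kind> paramTypes) {
    return paramTypes.stream().noneMatch(CONTAINERS::contains);
  }

  public static void main(String[] args) {
    List<Kind> interval = Arrays.asList(Kind.CALENDAR_INTERVAL);
    System.out.println(allAtomic(interval));                 // prints false
    System.out.println(allLiteralParamsAreScalar(interval)); // prints true
  }
}
```

Under these assumptions the deny-list form admits an interval parameter while still rejecting arrays, maps, and structs.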


/**
* Return a Reducer for a reducible function on another reducible function
* Handles both parameterized (bucket, truncate) and non-parameterized (days, hours) functions.
*/
private def reducer(
thisFunction: ReducibleFunction[_, _],
thisNumBucketsOpt: Option[Int],
thisExpr: TransformExpression,
otherFunction: ReducibleFunction[_, _],
otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
case (Some(numBuckets), Some(otherNumBuckets)) =>
thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
case _ => thisFunction.reducer(otherFunction)
otherExpr: TransformExpression): Option[Reducer[_, _]] = {
if (!thisExpr.sameArgumentLayout(otherExpr) ||
!thisExpr.scalarLiteralParams || !otherExpr.scalarLiteralParams) {
return None
}

val thisParams = thisExpr.extractParameters
val otherParams = otherExpr.extractParameters
val thisName = thisExpr.function.canonicalName()

// Gate on DataType, not the boxed runtime class (DateType/YearMonthInterval box to Int).
def isSingleInt(p: Array[V2Literal[_]]): Boolean = {
p.length == 1 && p(0).dataType == IntegerType
}

// Run a reducer overload; a thrown exception or a null both become None. warnOnUoe logs a hint
// when the function implements no usable reducer overload.
def attempt[R](call: => R, warnOnUoe: Boolean): Option[R] = {
val t = Try(Option(call))
if (warnOnUoe) {
t.failed.foreach {
case _: UnsupportedOperationException =>
logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} threw " +
log"UnsupportedOperationException; treating as not reducible. Override " +
log"reducer(Literal[], ReducibleFunction, Literal[]) to enable SPJ.")
case _ =>
}
}
t.toOption.flatten
}

if (thisParams.isEmpty && otherParams.isEmpty) {
attempt(thisFunction.reducer(otherFunction), warnOnUoe = true)
} else if (isSingleInt(thisParams) && isSingleInt(otherParams)) {
// Try the deprecated int API first (legacy connectors); fall back to the generalized overload
// when it is absent or returns null. Option.orElse fires on None, covering both.
attempt(thisFunction.reducer(
thisParams(0).value().asInstanceOf[Int], otherFunction,
otherParams(0).value().asInstanceOf[Int]), warnOnUoe = false)
.orElse(
attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true))
} else {
// Parameterized functions (bucket, truncate, etc.)
attempt(thisFunction.reducer(thisParams, otherFunction, otherParams), warnOnUoe = true)
}
Comment on lines +187 to 216

Member

Optional cleanup (non-blocking): the warnOnUoe flag encodes a whole-function question -- "did no overload turn out to be implemented?" -- as a per-attempt flag. You can drop the flag by giving attempt a three-state result (None = not implemented / threw, Some(None) = implemented but not reducible, Some(r) = reducible) and deciding the warning once from the aggregate.

As a bonus this fixes a subtle false-warning: today if a connector implements the deprecated API and deliberately returns null (not reducible) while not implementing the new API, the terminal attempt throws UOE and warnOnUoe = true fires -- even though the function is implemented and intentionally said "no". With the version below, that case is Some(None) and does not trigger the hint; the warning fires only when every overload threw.

Note Try(Option(call)).toOption preserves the current behavior of treating any throwable as "not implemented" (same as the existing t.toOption.flatten). If you'd rather only treat UnsupportedOperationException that way and rethrow everything else, a Try(...) match { case Success(r) => Some(r); case Failure(_: UnsupportedOperationException) => None; case Failure(e) => throw e } form works too.

Suggested change
// Probe a reducer overload, distinguishing three outcomes so the warning fires only when *no*
// overload is implemented -- not on the deprecated-API probe, which throws UOE by design for
// connectors that implement only the new Literal[] API:
// None -> overload not implemented (threw, e.g. UnsupportedOperationException)
// Some(None) -> implemented, but not reducible for these params (returned null)
// Some(r) -> implemented and reducible
def attempt(call: => Reducer[_, _]): Option[Option[Reducer[_, _]]] =
Try(Option(call)).toOption
val attempts: Seq[Option[Option[Reducer[_, _]]]] =
if (thisParams.isEmpty && otherParams.isEmpty) {
Seq(attempt(thisFunction.reducer(otherFunction)))
} else if (isSingleInt(thisParams) && isSingleInt(otherParams)) {
// Try the deprecated int API first (legacy connectors), then the generalized overload.
Seq(
attempt(thisFunction.reducer(
thisParams(0).value().asInstanceOf[Int], otherFunction,
otherParams(0).value().asInstanceOf[Int])),
attempt(thisFunction.reducer(thisParams, otherFunction, otherParams)))
} else {
// Parameterized functions (bucket, truncate, etc.)
Seq(attempt(thisFunction.reducer(thisParams, otherFunction, otherParams)))
}
// First implemented-and-reducible overload wins. Warn only when every overload threw (i.e.
// nothing is implemented); a deliberate null from an implemented overload is Some(None) and
// does not trigger the hint.
val result = attempts.flatten.flatten.headOption
if (result.isEmpty && attempts.forall(_.isEmpty)) {
logWarning(log"V2 function ${MDC(FUNCTION_NAME, thisName)} implements no reducer; " +
log"treating as not reducible. Override " +
log"reducer(Literal[], ReducibleFunction, Literal[]) to enable SPJ.")
}
result

Option(res)
}

override def dataType: DataType = function.resultType()
@@ -118,10 +222,7 @@ case class TransformExpression(
copy(children = newChildren)

private lazy val resolvedFunction: Option[Expression] = this match {
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
Seq(Literal(numBuckets)) ++ arguments))
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
case TransformExpression(scalarFunc: ScalarFunction[_], arguments) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
case _ => None
}
@@ -136,3 +237,18 @@ case class TransformExpression(
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
}

object TransformExpression {
/**
* Whether `e` is a bare column reference: an [[Attribute]] or a [[GetStructField]] chain
* (struct-field access on a column). Shared by [[TransformExpression.isSameFunction]] and by
* `KeyedPartitioning.supportsExpressions`, which both decide whether a transform's single
* non-literal argument is a plain column.
*/
@tailrec
private[sql] def isColumnRef(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isColumnRef(g.child)
case _ => false
}
}
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
import org.apache.spark.sql.connector.read.{SampleMethod => V2SampleMethod}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
@@ -116,17 +116,6 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = trans match {
case IdentityTransform(ref) =>
Some(resolveRef[NamedExpression](ref, query))
case BucketTransform(numBuckets, refs, sorted)
if sorted.isEmpty && refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
val resolvedRefs = refs.map(r => resolveRef[NamedExpression](r, query))
// Create a dummy reference for `numBuckets` here and use that, together with `refs`, to
// look up the V2 function.
val numBucketsRef = AttributeReference("numBuckets", IntegerType, nullable = false)()
funCatalogOpt.flatMap { catalog =>
loadV2FunctionOpt(catalog, "bucket", Seq(numBucketsRef) ++ resolvedRefs).map { bound =>
TransformExpression(bound, resolvedRefs, Some(numBuckets))
}
}
case NamedTransform(name, args) =>
val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
funCatalogOpt.flatMap { catalog =>
Original file line number Diff line number Diff line change
@@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.plans.physical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
@@ -641,19 +640,15 @@ object KeyedPartitioning {

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
transform.children.size == 1 && isReference(transform.children.head)
}

@tailrec
def isReference(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isReference(g.child)
case _ => false
// Should only consider column references, not literals.
val nonLiteralChildren = transform.children.filterNot(_.isInstanceOf[Literal])
// We need exactly one column reference per transform.
nonLiteralChildren.size == 1 && TransformExpression.isColumnRef(nonLiteralChildren.head)

Member

[P1] Apply bound-function input casts before evaluating parameterized transforms through SPJ

This gate now admits parameterized transforms whose bound function requests legal implicit casts. BoundFunction.inputTypes() explicitly allows types to differ from those passed to bind, with Spark responsible for casting. However, the scan path stores the raw Catalyst children in TransformExpression, and the identity-vs-transform reducer later binds and directly evaluates that expression without running Analyzer type coercion.

For a concrete case, a connector can report truncate(str_col, Expressions.literal(2.toShort)); binding (StringType, ShortType) to a scalar function that declares (StringType, IntegerType) is legal. This transform passes the new gate, but the synthetic reducer evaluates the raw Short literal through ApplyFunctionExpression. Its SpecificInternalRow(IntegerType) then attempts to cast the boxed Short to Int, raising ClassCastException. Before parameterized transforms were admitted here, this query would fall back to a shuffle.

Please coerce the transform children to function.inputTypes() before direct reducer evaluation (or reject mismatched types from this SPJ path), and add an identity-vs-parameterized-transform regression covering an implicitly cast literal.
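
The failure mode is easy to reproduce in isolation. The sketch below is plain Java, not Spark code: `ShortLiteralCast` and its methods are hypothetical names, and the direct cast is only an analogy for reading a raw `Short` literal through an `Int`-typed slot. It shows why a boxed `Short` cannot be cast to `Integer`, and how widening through `Number` (the analogue of applying the implicit cast first) avoids the exception.

```java
public class ShortLiteralCast {
  // Direct cast of a boxed value; fails unless the runtime class is Integer.
  static int unsafeAsInt(Object boxed) {
    return (Integer) boxed;
  }

  // Widen through Number first, analogous to coercing to the declared input type.
  static int coerceToInt(Object boxed) {
    if (boxed instanceof Number) {
      return ((Number) boxed).intValue();
    }
    throw new IllegalArgumentException("not numeric: " + boxed);
  }

  // Reports whether the direct cast raises ClassCastException for this value.
  static boolean throwsClassCast(Object boxed) {
    try {
      unsafeAsInt(boxed);
      return false;
    } catch (ClassCastException e) {
      return true;
    }
  }

  public static void main(String[] args) {
    Object width = Short.valueOf((short) 2);
    System.out.println(throwsClassCast(width)); // prints true
    System.out.println(coerceToInt(width));     // prints 2
  }
}
```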

Author

@sunchao Thanks. I was able to reproduce this issue, which raises a ClassCastException.
I'll work on a fix and let you know.

}

expressions.forall {
case t: TransformExpression if isSupportedTransform(t) => true
case e: Expression if isReference(e) => true
case e: Expression if TransformExpression.isColumnRef(e) => true
case _ => false
}
}
@@ -1335,7 +1330,13 @@ case class KeyedShuffleSpec(

val newExpressions = partitioning.expressions.zip(keyPositions).map {
case (te: TransformExpression, positionSet) =>
te.copy(children = te.children.map(_ => clustering(positionSet.head)))
// Preserve literal parameters (e.g., numBuckets, truncate width)
// while replacing only column references with the new clustering expression
val newChildren = te.children.map {
case l: Literal => l // Keep literals as-is
case _ => clustering(positionSet.head) // Replace column references
}
te.copy(children = newChildren)
case (_, positionSet) => clustering(positionSet.head)
}
KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped)