Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.internal.config.ConfigBindingPolicy
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode._
import org.apache.spark.sql.catalyst.analysis.resolver.{
AnalyzerBridgeState,
HybridAnalyzer,
Expand Down Expand Up @@ -3788,9 +3789,16 @@ class Analyzer(
validateStoreAssignmentPolicy()
TableOutputResolver.suitableForByNameCheck(v2Write.isByName,
expected = v2Write.table.output, queryOutput = v2Write.query.output)
// When both `conf.coerceInsertNestedTypes` and schema evolution on the write are enabled,
// missing top-level columns AND missing nested struct fields are filled with defaults/null
// (RECURSE mode). Otherwise, only missing top-level columns are filled (FILL mode);
// missing nested struct fields still cause schema enforcement errors.
val defaultValueFillMode =
if (conf.coerceInsertNestedTypes && v2Write.schemaEvolutionEnabled) RECURSE
else FILL
val projection = TableOutputResolver.resolveOutputColumns(
v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf,
supportColDefaultValue = true)
defaultValueFillMode)
if (projection != v2Write.query) {
val cleanedTable = v2Write.table match {
case r: DataSourceV2Relation =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ object TableOutputResolver extends SQLConfHelper with Logging {

/**
* Modes for filling in default or null values for missing columns.
* If FILL, fill missing top-level columns with their default values.
* If RECURSE, fill missing top-level columns and also recurse into nested struct
* fields to fill null.
* If FILL, fill missing top-level columns with their default values (by-name reorder path).
* If RECURSE, fill missing top-level columns (including trailing columns on the by-position
* path for INSERT with schema evolution when enabled) and recurse into nested structs,
* arrays, and maps to fill missing struct fields with null or defaults.
* If NONE, do not fill any missing columns.
*/
object DefaultValueFillMode extends Enumeration {
Expand Down Expand Up @@ -92,33 +93,38 @@ object TableOutputResolver extends SQLConfHelper with Logging {
query: LogicalPlan,
byName: Boolean,
conf: SQLConf,
supportColDefaultValue: Boolean = false): LogicalPlan = {
defaultValueFillMode: DefaultValueFillMode.Value = NONE): LogicalPlan = {

if (expected.size < query.output.size) {
throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
tableName, expected.map(_.name), query.output)
}

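// In RECURSE mode, by-position resolution allows fewer source columns than target columns:
// the trailing target columns are filled with their default value (or null), and an error
// is still raised for a trailing column that has neither. In other modes, a column count
// mismatch in by-position resolution is an error.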
val fillDefaultValue = defaultValueFillMode == RECURSE
val errors = new mutable.ArrayBuffer[String]()
val resolved: Seq[NamedExpression] = if (byName) {
// If a top-level column does not have a corresponding value in the input query, fill with
// the column's default value. We need to pass `fillDefaultValue` as FILL here, if the
// `supportColDefaultValue` parameter is also true.
val defaultValueFillMode = if (supportColDefaultValue) FILL else NONE
// By-name resolution: the defaultValueFillMode is passed through to control whether
// missing top-level columns are filled (FILL/RECURSE) and whether missing nested
// struct fields are also filled (RECURSE only).
reorderColumnsByName(
tableName,
query.output,
expected,
conf,
errors += _,
Nil,
defaultValueFillMode)
defaultValueFillMode,
enforceFullOutput = true)
} else {
if (expected.size > query.output.size) {
if (expected.size > query.output.size && !fillDefaultValue) {
throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
tableName, expected.map(_.name), query.output)
}
resolveColumnsByPosition(tableName, query.output, expected, conf, errors += _)
resolveColumnsByPosition(
tableName, query.output, expected, conf, errors += _, fillDefaultValue = fillDefaultValue)
}

if (errors.nonEmpty) {
Expand Down Expand Up @@ -157,17 +163,17 @@ object TableOutputResolver extends SQLConfHelper with Logging {
case (valueType: StructType, colType: StructType) =>
val resolvedValue = resolveStructType(
tableName, value, valueType, col, colType,
byName = true, conf, addError, colPath, fillChildDefaultValue)
byName = true, conf, addError, colPath, fillChildDefaultValue, enforceFullOutput = false)
resolvedValue.getOrElse(value)
case (valueType: ArrayType, colType: ArrayType) =>
val resolvedValue = resolveArrayType(
tableName, value, valueType, col, colType,
byName = true, conf, addError, colPath, fillChildDefaultValue)
byName = true, conf, addError, colPath, fillChildDefaultValue, enforceFullOutput = false)
resolvedValue.getOrElse(value)
case (valueType: MapType, colType: MapType) =>
val resolvedValue = resolveMapType(
tableName, value, valueType, col, colType,
byName = true, conf, addError, colPath, fillChildDefaultValue)
byName = true, conf, addError, colPath, fillChildDefaultValue, enforceFullOutput = false)
resolvedValue.getOrElse(value)
case _ =>
checkUpdate(tableName, value, col, conf, addError, colPath)
Expand Down Expand Up @@ -304,7 +310,8 @@ object TableOutputResolver extends SQLConfHelper with Logging {
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String] = Nil,
defaultValueFillMode: DefaultValueFillMode.Value): Seq[NamedExpression] = {
defaultValueFillMode: DefaultValueFillMode.Value,
enforceFullOutput: Boolean = false): Seq[NamedExpression] = {
val matchedCols = mutable.HashSet.empty[String]
val reordered = expectedCols.flatMap { expectedCol =>
val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name))
Expand Down Expand Up @@ -336,15 +343,15 @@ object TableOutputResolver extends SQLConfHelper with Logging {
case (matchedType: StructType, expectedType: StructType) =>
resolveStructType(
tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
byName = true, conf, addError, newColPath, childFillDefaultValue)
byName = true, conf, addError, newColPath, childFillDefaultValue, enforceFullOutput)
case (matchedType: ArrayType, expectedType: ArrayType) =>
resolveArrayType(
tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
byName = true, conf, addError, newColPath, childFillDefaultValue)
byName = true, conf, addError, newColPath, childFillDefaultValue, enforceFullOutput)
case (matchedType: MapType, expectedType: MapType) =>
resolveMapType(
tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
byName = true, conf, addError, newColPath, childFillDefaultValue)
byName = true, conf, addError, newColPath, childFillDefaultValue, enforceFullOutput)
case _ =>
checkField(
tableName, actualExpectedCol, matchedCol, byName = true, conf, addError, newColPath)
Expand All @@ -366,6 +373,11 @@ object TableOutputResolver extends SQLConfHelper with Logging {
} else {
reordered
}
} else if (enforceFullOutput) {
val colName =
if (colPath.nonEmpty) colPath.quoted
else expectedCols.map(_.name).map(toSQLId).mkString(", ")
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(tableName, colName)
} else {
Nil
}
Expand All @@ -377,7 +389,8 @@ object TableOutputResolver extends SQLConfHelper with Logging {
expectedCols: Seq[Attribute],
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String] = Nil): Seq[NamedExpression] = {
colPath: Seq[String] = Nil,
fillDefaultValue: Boolean = false): Seq[NamedExpression] = {
val actualExpectedCols = expectedCols.map { attr =>
attr.withDataType { CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) }
}
Expand All @@ -393,7 +406,7 @@ object TableOutputResolver extends SQLConfHelper with Logging {
tableName, colPath.quoted, extraColsStr
)
}
} else if (inputCols.size < actualExpectedCols.size) {
} else if (inputCols.size < actualExpectedCols.size && !fillDefaultValue) {
val missingColsStr = actualExpectedCols.takeRight(actualExpectedCols.size - inputCols.size)
.map(col => toSQLId(col.name))
.mkString(", ")
Expand All @@ -407,25 +420,48 @@ object TableOutputResolver extends SQLConfHelper with Logging {
}
}

inputCols.zip(actualExpectedCols).flatMap { case (inputCol, expectedCol) =>
val matched = inputCols.zip(actualExpectedCols).flatMap { case (inputCol, expectedCol) =>
val newColPath = colPath :+ expectedCol.name
(inputCol.dataType, expectedCol.dataType) match {
case (inputType: StructType, expectedType: StructType) =>
resolveStructType(
tableName, inputCol, inputType, expectedCol, expectedType,
byName = false, conf, addError, newColPath, fillDefaultValue = false)
byName = false, conf, addError, newColPath, fillDefaultValue, enforceFullOutput = true)
case (inputType: ArrayType, expectedType: ArrayType) =>
resolveArrayType(
tableName, inputCol, inputType, expectedCol, expectedType,
byName = false, conf, addError, newColPath, fillDefaultValue = false)
byName = false, conf, addError, newColPath, fillDefaultValue, enforceFullOutput = true)
case (inputType: MapType, expectedType: MapType) =>
resolveMapType(
tableName, inputCol, inputType, expectedCol, expectedType,
byName = false, conf, addError, newColPath, fillDefaultValue = false)
byName = false, conf, addError, newColPath, fillDefaultValue, enforceFullOutput = true)
case _ =>
checkField(tableName, expectedCol, inputCol, byName = false, conf, addError, newColPath)
}
}

val defaults = if (fillDefaultValue) {
actualExpectedCols.drop(inputCols.size).map { expectedCol =>
val defaultExpr = getDefaultValueExprOrNullLit(
expectedCol, conf.useNullsForMissingDefaultColumnValues)
if (defaultExpr.isEmpty) {
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(
tableName, (colPath :+ expectedCol.name).quoted)
}
applyColumnMetadata(defaultExpr.get, expectedCol)
}
} else {
Nil
}

val result = matched ++ defaults
if (result.length != actualExpectedCols.size) {
val colName =
if (colPath.nonEmpty) colPath.quoted
else actualExpectedCols.map(_.name).map(toSQLId).mkString(", ")
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(tableName, colName)
}
result
}

private[sql] def checkNullability(
Expand All @@ -447,6 +483,7 @@ object TableOutputResolver extends SQLConfHelper with Logging {
input.nullable && !attr.nullable && conf.storeAssignmentPolicy != StoreAssignmentPolicy.LEGACY
}

// scalastyle:off argcount
private def resolveStructType(
tableName: String,
input: Expression,
Expand All @@ -457,18 +494,19 @@ object TableOutputResolver extends SQLConfHelper with Logging {
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String],
fillDefaultValue: Boolean): Option[NamedExpression] = {
fillDefaultValue: Boolean,
enforceFullOutput: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)
val fields = inputType.zipWithIndex.map { case (f, i) =>
Alias(GetStructField(nullCheckedInput, i, Some(f.name)), f.name)()
}
val defaultValueMode = if (fillDefaultValue) RECURSE else NONE
val resolved = if (byName) {
reorderColumnsByName(tableName, fields, toAttributes(expectedType), conf, addError, colPath,
defaultValueMode)
defaultValueMode, enforceFullOutput)
} else {
resolveColumnsByPosition(
tableName, fields, toAttributes(expectedType), conf, addError, colPath)
tableName, fields, toAttributes(expectedType), conf, addError, colPath, fillDefaultValue)
}
if (resolved.length == expectedType.length) {
val struct = CreateStruct(resolved)
Expand All @@ -478,6 +516,11 @@ object TableOutputResolver extends SQLConfHelper with Logging {
struct
}
Some(applyColumnMetadata(res, expected))
} else if (enforceFullOutput) {
val colName =
if (colPath.nonEmpty) colPath.quoted
else expectedType.fields.map(_.name).map(toSQLId).mkString(", ")
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(tableName, colName)
} else {
None
}
Expand All @@ -493,17 +536,19 @@ object TableOutputResolver extends SQLConfHelper with Logging {
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String],
fillDefaultValue: Boolean): Option[NamedExpression] = {
fillDefaultValue: Boolean,
enforceFullOutput: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)
val param = NamedLambdaVariable("element", inputType.elementType, inputType.containsNull)
val fakeAttr =
AttributeReference("element", expectedType.elementType, expectedType.containsNull)()
val res = if (byName) {
val defaultValueMode = if (fillDefaultValue) RECURSE else NONE
reorderColumnsByName(tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath,
defaultValueMode)
defaultValueMode, enforceFullOutput)
} else {
resolveColumnsByPosition(tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath)
resolveColumnsByPosition(
tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath, fillDefaultValue)
}
if (res.length == 1) {
val castedArray =
Expand All @@ -515,6 +560,9 @@ object TableOutputResolver extends SQLConfHelper with Logging {
ArrayTransform(nullCheckedInput, func)
}
Some(applyColumnMetadata(castedArray, expected))
} else if (enforceFullOutput) {
val colName = if (colPath.nonEmpty) colPath.quoted else toSQLId(expected.name)
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(tableName, colName)
} else {
None
}
Expand All @@ -530,17 +578,19 @@ object TableOutputResolver extends SQLConfHelper with Logging {
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String],
fillDefaultValue: Boolean): Option[NamedExpression] = {
fillDefaultValue: Boolean,
enforceFullOutput: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)

val keyParam = NamedLambdaVariable("key", inputType.keyType, nullable = false)
val fakeKeyAttr = AttributeReference("key", expectedType.keyType, nullable = false)()
val defaultValueFillMode = if (fillDefaultValue) RECURSE else NONE
val resKey = if (byName) {
reorderColumnsByName(tableName, Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath,
defaultValueFillMode)
defaultValueFillMode, enforceFullOutput)
} else {
resolveColumnsByPosition(tableName, Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath)
resolveColumnsByPosition(
tableName, Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath, fillDefaultValue)
}

val valueParam =
Expand All @@ -549,10 +599,10 @@ object TableOutputResolver extends SQLConfHelper with Logging {
AttributeReference("value", expectedType.valueType, expectedType.valueContainsNull)()
val resValue = if (byName) {
reorderColumnsByName(tableName, Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath,
defaultValueFillMode)
defaultValueFillMode, enforceFullOutput)
} else {
resolveColumnsByPosition(
tableName, Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath)
tableName, Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath, fillDefaultValue)
}

if (resKey.length == 1 && resValue.length == 1) {
Expand All @@ -577,10 +627,14 @@ object TableOutputResolver extends SQLConfHelper with Logging {
MapFromArrays(newKeys, newValues)
}
Some(applyColumnMetadata(casted, expected))
} else if (enforceFullOutput) {
val colName = if (colPath.nonEmpty) colPath.quoted else toSQLId(expected.name)
throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(tableName, colName)
} else {
None
}
}
// scalastyle:on argcount

// For table insertions, capture the overflow errors and show a proper message.
// Without this method, the overflow errors of castings will show hints for turning off ANSI SQL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7270,6 +7270,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal, experimental flag (off by default) controlling nested-type coercion for
// INSERT INTO ... WITH SCHEMA EVOLUTION. It is read through the
// `SQLConf.coerceInsertNestedTypes` accessor. The V2 write resolution in the Analyzer
// checks it together with `v2Write.schemaEvolutionEnabled`, so turning this flag on
// alone does not enable the filling behavior described in the doc text below.
val INSERT_INTO_NESTED_TYPE_COERCION_ENABLED =
buildConf("spark.sql.insertNestedTypeCoercion.enabled")
.internal()
.doc("If enabled, allow INSERT INTO WITH SCHEMA EVOLUTION to fill missing nested " +
"struct fields with null when the source has fewer nested fields than the target " +
"table. Also relaxes by-position column-count enforcement so trailing missing " +
"top-level columns are filled with their default value (or null). This is " +
"experimental and the semantics may change.")
.version("4.2.0")
.booleanConf
.createWithDefault(false)

val TIME_TYPE_ENABLED =
buildConf("spark.sql.timeType.enabled")
.internal()
Expand Down Expand Up @@ -8597,6 +8609,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def coerceMergeNestedTypes: Boolean =
getConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED)

// Returns the current value of `spark.sql.insertNestedTypeCoercion.enabled`
// (SQLConf.INSERT_INTO_NESTED_TYPE_COERCION_ENABLED), which defaults to false.
def coerceInsertNestedTypes: Boolean =
getConf(SQLConf.INSERT_INTO_NESTED_TYPE_COERCION_ENABLED)

def isTimeTypeEnabled: Boolean = getConf(SQLConf.TIME_TYPE_ENABLED)

def listaggAllowDistinctCastWithOrder: Boolean = getConf(LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ object PreprocessTableInsertion extends ResolveInsertionBase {
query,
byName,
conf,
supportColDefaultValue = true)
TableOutputResolver.DefaultValueFillMode.FILL)
} catch {
case e: AnalysisException if staticPartCols.nonEmpty &&
(e.getCondition == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" ||
Expand Down
Loading