diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 03b5517f6df05..a6be98c8a3aae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) { getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => // For case-sensitivity aware field resolution, we should take `ordinal` which - // points to correct struct field. + // points to correct struct field, because `ExtractValue` actually does column + // name resolving correctly. val selectedField = a.child.dataType.asInstanceOf[ArrayType] .elementType.asInstanceOf[StructType](a.ordinal) val prunedField = projSchema(selectedField.name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 0be2792bfd7db..5b12667f4a884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -231,6 +231,27 @@ object NestedColumnAliasing { * of it. */ object GeneratorNestedColumnAliasing { + // Partitions `attrToAliases` based on whether the attribute is in Generator's output. 
+ private def aliasesOnGeneratorOutput( + attrToAliases: Map[ExprId, Seq[Alias]], + generatorOutput: Seq[Attribute]) = { + val generatorOutputExprId = generatorOutput.map(_.exprId) + attrToAliases.partition { k => + generatorOutputExprId.contains(k._1) + } + } + + // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor + // is in Generator's output. + private def nestedFieldOnGeneratorOutput( + nestedFieldToAlias: Map[ExtractValue, Alias], + generatorOutput: Seq[Attribute]) = { + val generatorOutputSet = AttributeSet(generatorOutput) + nestedFieldToAlias.partition { pair => + pair._1.references.subsetOf(generatorOutputSet) + } + } + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we // need to prune nested columns through Project and under Generate. The difference is @@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing { // On top on `Generate`, a `Project` that might have nested column accessors. // We try to get alias maps for both project list and generator's children expressions. val exprsToPrune = projectList ++ g.generator.children - NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map { + NestedColumnAliasing.getAliasSubMap(exprsToPrune).map { case (nestedFieldToAlias, attrToAliases) => + val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) = + nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput) + val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) = + aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput) + + // Push nested column accessors through `Generator`. // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. 
- val newChild = - NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases) - Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) + val newChild = NestedColumnAliasing.replaceWithAliases(g, + nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator) + val pushedThrough = Project(NestedColumnAliasing + .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild) + + // If the generator output is `ArrayType`, we cannot push through the extractor. + // It is because we don't allow field extractor on two-level array, + // i.e., attr.field when attr is an ArrayType(ArrayType(...)). + // Similarly, we also cannot push through if the child of generator is `MapType`. + g.generator.children.head.dataType match { + case _: MapType => return Some(pushedThrough) + case ArrayType(_: ArrayType, _) => return Some(pushedThrough) + case _ => + } + + // Pruning on `Generator`'s output. We only process single field case. + // For multiple field case, we cannot directly move field extractor into + // the generator expression. A workaround is to re-construct array of struct + // from multiple fields. But it will be more complicated and may not be worth it. + // TODO(SPARK-34956): support multiple fields. + if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) { + pushedThrough + } else { + // Only one nested column accessor. + // E.g., df.select(explode($"items").as("item")).select($"item.a") + pushedThrough match { + case p @ Project(_, newG: Generate) => + // Replace the child expression of `ExplodeBase` generator with + // nested column accessor. 
+ // E.g., df.select(explode($"items").as("item")).select($"item.a") => + // df.select(explode($"items.a").as("item.a")) + val rewrittenG = newG.transformExpressions { + case e: ExplodeBase => + val extractor = nestedFieldsOnGenerator.head._1.transformUp { + case _: Attribute => + e.child + case g: GetStructField => + ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver) + } + e.withNewChildren(Seq(extractor)) + } + + // As we change the child of the generator, its output data type must be updated. + val updatedGeneratorOutput = rewrittenG.generatorOutput + .zip(rewrittenG.generator.elementSchema.toAttributes) + .map { case (oldAttr, newAttr) => + newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) + } + assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length, + "Updated generator output must have the same length " + + "with original generator output.") + val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput) + + // Replace nested column accessor with generator output. + p.withNewChildren(Seq(updatedGenerate)).transformExpressions { + case f: ExtractValue if nestedFieldsOnGenerator.contains(f) => + updatedGenerate.output + .find(a => attrToAliasesOnGenerator.contains(a.exprId)) + .getOrElse(f) + } + + case other => + // We should not reach here. 
+ throw new IllegalStateException(s"Unreasonable plan after optimization: $other") + } + } } case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index 0ae4d3f6e6801..a856caa6781e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { comparePlans(optimized, expected) } - test("Nested field pruning for Project and Generate: not prune on generator output") { + test("Nested field pruning for Project and Generate: multiple-field case is not supported") { val companies = LocalRelation( 'id.int, 'employers.array(employer)) val query = companies .generate(Explode('employers.getField("company")), outputNames = Seq("company")) - .select('company.getField("name")) + .select('company.getField("name"), 'company.getField("address")) .analyze val optimized = Optimize.execute(query) @@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0), outputNames = Seq("company")) - .select('company.getField("name").as("company.name")) + .select('company.getField("name").as("company.name"), + 'company.getField("address").as("company.address")) .analyze comparePlans(optimized, expected) } @@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { ).analyze comparePlans(optimized2, expected2) } + + test("SPARK-34638: nested column prune on generator output for one field") { + val companies = LocalRelation( + 'id.int, + 'employers.array(employer)) + + val query = companies + 
.generate(Explode('employers.getField("company")), outputNames = Seq("company")) + .select('company.getField("name")) + .analyze + val optimized = Optimize.execute(query) + + val aliases = collectGeneratedAliases(optimized) + + val expected = companies + .select('employers.getField("company").getField("name").as(aliases(0))) + .generate(Explode($"${aliases(0)}"), + unrequiredChildIndex = Seq(0), + outputNames = Seq("company")) + .select('company.as("company.name")) + .analyze + comparePlans(optimized, expected) + } } object NestedColumnAliasingSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 765d2fc584a7d..ac5c28953a5d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -351,6 +351,43 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("SPARK-34638: nested column prune on generator output") { + val query1 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first") + checkScan(query1, "struct>>") + checkAnswer(query1, Row("Susan") :: Nil) + + // Currently we don't prune multiple field case. 
+ val query2 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.middle") + checkScan(query2, "struct>>") + checkAnswer(query2, Row("Susan", "Z.") :: Nil) + + val query3 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.middle", "friend") + checkScan(query3, "struct>>") + checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil) + } + + testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query1 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.First") + checkScan(query1, "struct>>") + checkAnswer(query1, Row("Susan") :: Nil) + + val query2 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.MIDDLE") + checkScan(query2, "struct>>") + checkAnswer(query2, Row("Z.") :: Nil) + } + } + testSchemaPruning("select one deep nested complex field after repartition") { val query = sql("select * from contacts") .repartition(100) @@ -816,4 +853,21 @@ abstract class SchemaPruningSuite Row("John", "Y.") :: Nil) } } + + test("SPARK-34638: queries should not fail on unsupported cases") { + withTable("nested_array") { + sql("select * from values array(array(named_struct('a', 1, 'b', 3), " + + "named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array") + val query = sql("select d.a from (select explode(c) d from " + + "(select explode(items) c from nested_array))") + checkAnswer(query, Row(1) :: Row(2) :: Nil) + } + + withTable("map") { + sql("select * from values map(1, named_struct('a', 1, 'b', 3), " + + "2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map") + val query = sql("select d.a from (select explode(items) (c, d) from map)") + checkAnswer(query, Row(1) :: Row(2) :: Nil) + } + } }