From 07c7bd49039bb6fea424dcd46d5f74e301ccf6f4 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Wed, 12 Jan 2022 01:03:16 +0800 Subject: [PATCH] fix --- .../optimizer/NestedColumnAliasing.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 54 +++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 77a25ecb04ef5..9d63f4e94647c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -245,11 +245,13 @@ object NestedColumnAliasing { val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]() exprList.foreach { e => collectRootReferenceAndExtractValue(e).foreach { - case ev: ExtractValue => + // we can not alias the attr from lambda variable whose expr id is not available + case ev: ExtractValue if ev.find(_.isInstanceOf[NamedLambdaVariable]).isEmpty => if (ev.references.size == 1) { nestedFieldReferences.append(ev) } case ar: AttributeReference => otherRootReferences.append(ar) + case _ => // ignore } } val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 374c86775e4ac..7482d76207388 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -3033,6 +3033,60 @@ class DataFrameSuite extends QueryTest } } } + + test("SPARK-37855: IllegalStateException when transforming an array inside a nested struct") { + def makeInput(): DataFrame = { + val innerElement1 = Row(3, 3.12) + val innerElement2 = Row(4, 2.1) + val innerElement3 = Row(1, 985.2) + val innerElement4 = Row(10, 757548.0) + val innerElement5 = Row(1223, 0.665) + + val outerElement1 = Row(1, Row(List(innerElement1, innerElement2))) + val outerElement2 = Row(2, Row(List(innerElement3))) + val outerElement3 = Row(3, Row(List(innerElement4, innerElement5))) + + val data = Seq( + Row("row1", List(outerElement1)), + Row("row2", List(outerElement2, outerElement3)) + ) + + val schema = new StructType() + .add("name", StringType) + .add("outer_array", ArrayType(new StructType() + .add("id", IntegerType) + .add("inner_array_struct", new StructType() + .add("inner_array", ArrayType(new StructType() + .add("id", IntegerType) + .add("value", DoubleType) + )) + ) + )) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val df = makeInput().limit(2) + + val res = df.withColumn("extracted", transform( + col("outer_array"), + c1 => { + struct( + c1.getField("id").alias("outer_id"), + transform( + c1.getField("inner_array_struct").getField("inner_array"), + c2 => { + struct( + c2.getField("value").alias("inner_value") + ) + } + ) + ) + } + )) + + assert(res.collect.length == 2) + } } case class GroupByKey(a: Int, b: Int)