From fc3f931d2b4a35146ac5b93d1c5905e91059fd39 Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Mon, 23 Nov 2020 14:41:41 +0800 Subject: [PATCH 1/8] [SPARK-32002][SQL]Support ExtractValue from nested ArrayStruct --- .../expressions/ProjectionOverSchema.scala | 13 ++ .../catalyst/expressions/SelectedField.scala | 28 +++- .../expressions/complexTypeExtractors.scala | 83 ++++++++++- .../apache/spark/sql/ComplexTypesSuite.scala | 17 ++- .../NestArraySchemaPruningSuite.scala | 129 ++++++++++++++++++ 5 files changed, 263 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 241c761624b76..67e25fc03343b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -51,6 +51,19 @@ case class ProjectionOverSchema(schema: StructType) { s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" ) } + case ExtractNestedArrayField(child, _, _, field, containsNullSeq) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, ExtractNestedArray(projSchema @ StructType(_), _, _)) => + ExtractNestedArrayField(projection, + projSchema.fieldIndex(field.name), + projSchema.fields.length, + projSchema(field.name), + containsNullSeq) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for ExtractNestedArrayField: ${projSchema.toString}" + ) + } case MapKeys(child) => getProjection(child).map { projection => MapKeys(projection) } case MapValues(child) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala index f2acb75ea6ac4..aa768f69e2222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala @@ -65,7 +65,8 @@ object SelectedField { /** * Convert an expression into the parts of the schema (the field) it accesses. */ - private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): Option[StructField] = { + private def selectField(expr: Expression, dataTypeOpt: Option[DataType], + nestArray: Boolean = false): Option[StructField] = { expr match { case a: Attribute => dataTypeOpt.map { dt => @@ -81,16 +82,37 @@ object SelectedField { // GetArrayStructFields is the top level extractor. This means its result is // not pruned and we need to use the element type of the array its producing. field.dataType - case Some(ArrayType(dataType, _)) => + case Some(ArrayType(dataType, nullable)) => // GetArrayStructFields is part of a chain of extractors and its result is pruned // by a parent expression. In this case need to use the parent element type. - dataType + if (nestArray) ArrayType(dataType, nullable) else dataType case Some(x) => // This should not happen. throw new AnalysisException(s"DataType '$x' is not supported by GetArrayStructFields.") } val newField = StructField(field.name, newFieldDataType, field.nullable) selectField(child, Option(ArrayType(struct(newField), containsNull))) + case ExtractNestedArrayField(child, _, _, field @ StructField(_, _, _, _), _) => + val newFieldDataType = dataTypeOpt match { + case None => + // ExtractNestedArrayField is the top level extractor. This means its result is + // not pruned and we need to use the element type of the array its producing. + field.dataType + case Some(dataType) => + dataType + } + val structType = struct(StructField(field.name, newFieldDataType, field.nullable)) + + val newDataType = child match { + case ExtractNestedArrayField(_, _, _, childField, _) => + childField.dataType match { + case _: ArrayType => ArrayType(structType) + case _ => structType + } + case _: GetArrayStructFields => ArrayType(structType) + case _ => structType + } + selectField(child, Some(newDataType), nestArray = true) case GetMapValue(child, _, _) => // GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be // the top-level extractor. However it can be part of an extractor chain. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 767650d022200..078f60115b31e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, + CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -60,6 +61,13 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull || fields(ordinal).nullable) + case (ExtractNestedArray(StructType(fields), _, containsNullSeq), + NonNullLiteral(v, StringType)) if containsNullSeq.nonEmpty => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + ExtractNestedArrayField(child, ordinal, fields.length, + fields(ordinal).copy(name = fieldName), containsNullSeq) + case (_: ArrayType, _) => GetArrayItem(child, extraction) case (MapType(kt, _, _), _) => GetMapValue(child, extraction) @@ -218,6 +226,77 @@ case class GetArrayStructFields( } } +object ExtractNestedArray { + type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] + + def unapply(dataType: DataType): ReturnType = { + extractArrayType(dataType) + } + + def extractArrayType(dataType: DataType): ReturnType = { + dataType match { + case ArrayType(dt, containsNull) => + extractArrayType(dt) match { + case Some((d, cn, seq)) => Some((d, cn, containsNull +: seq)) + case None => Some((dt, containsNull, Seq.empty[Boolean])) + } + case _ => None + } + } +} + +case class ExtractNestedArrayField( + child: Expression, + ordinal: Int, + numFields: Int, + field: StructField, + containsNullSeq: Seq[Boolean]) extends UnaryExpression + with ExtractValue with NullIntolerant with CodegenFallback { + + protected override def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData] + new GenericArrayData( + (0 until array.numElements()).map(n => evalArrayItem(n, array, containsNullSeq.size))) + } + + private def evalArrayItem(original: Int, array: ArrayData, num: Int): ArrayData = { + if (array.isNullAt(original)) { + null + } + else { + val innerArray = array.get(original, nestedArrayType(num)).asInstanceOf[ArrayData] + new GenericArrayData((0 until innerArray.numElements()).map(n => { + if (num == 1) { + extractStruct(n, innerArray) + } + else { + evalArrayItem(n, innerArray, num - 1) + } + })) + } + } + + private def extractStruct(n: Int, array: ArrayData): Any = { + if (array.isNullAt(n)) { + null + } else { + val row = array.getStruct(n, numFields) + if (row.isNullAt(ordinal)) { + null + } else { + row.get(ordinal, field.dataType) + } + } + } + + override def dataType: DataType = ArrayType(nestedArrayType(0)) + + def nestedArrayType(num: Int): DataType = { + (num until containsNullSeq.size).reverse + .foldLeft(field.dataType) { (e, i) => ArrayType(e, containsNullSeq(i))} + } +} + /** * Returns the field at `ordinal` in the Array `child`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index bdcf7230e3211..ef715483fa192 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ - import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{ArrayType, StructType} +import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType} class ComplexTypesSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -117,4 +116,18 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema) checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null))) } + + test("SPARK-32002: Support ExtractValue from nested ArrayStruct") { + val jsonStr1 = """{"a": [{"b": [{"c": [1,2]}]}]}""" + val jsonStr2 = """{"a": [{"b": [{"c": [1]}, {"c": [2]}]}]}""" + val df = spark.read.json(Seq(jsonStr1, jsonStr2).toDS()) + checkAnswer(df.select($"a.b.c"), Row(Seq(Seq(Seq(1, 2)))) + :: Row(Seq(Seq(Seq(1), Seq(2)))) :: Nil) + } + + test("SPARK-32003: Support ExtractValue from nested ArrayStruct") { + val jsonStr1 = """{"a": [{"b": [{"c": [{"d": [1, 2]}]}]}]}""" + val df = spark.read.json(Seq(jsonStr1).toDS()) + checkAnswer(df.select($"a.b.c.d"), Row(Seq(Seq(Seq(Seq(1)))))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala new file mode 100644 index 0000000000000..99c5931b2c31b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/NestArraySchemaPruningSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType + + +class NestArraySchemaPruningSuite + extends QueryTest + with FileBasedDataSourceTest + with SchemaPruningTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + case class AdRecord(positions: Array[Positions]) + case class Positions(imps: Array[Impression]) + case class Impression(id: String, ad: Advertising, clicks: Array[Clicks]) + case class Advertising(index: Int) + case class Clicks(fraud_type: Int) + + val adRecords = AdRecord(Array(Positions(Array(Impression("1", Advertising(1), + Array(Clicks(0), Clicks(1))))))) :: AdRecord(Array(Positions(Array( + Impression("2", Advertising(2), Array(Clicks(1), Clicks(2))))))) :: Nil + + testSchemaPruning("Nested arrays for pruning schema") { + val queryIndex = sql("select positions.imps.ad.index from adRecords") + checkScan(queryIndex, + "struct>>>>>") + checkAnswer(queryIndex, Row(Seq(Seq(1))) :: Row(Seq(Seq(2))) :: Nil) + + val queryId = sql("select positions.imps.id from adRecords") + checkScan(queryId, + "struct>>>>") + checkAnswer(queryId, Row(Seq(Seq("1"))) :: Row(Seq(Seq("2"))) :: Nil) + + val queryIndexAndFraud = + sql("select positions.imps.ad.index, positions.imps.clicks.fraud_type from adRecords") + checkScan(queryIndexAndFraud, "struct, clicks:array>>>>>>") + checkAnswer(queryIndexAndFraud, Row(Seq(Seq(1)), Seq(Seq(Seq(0, 1)))) + :: Row(Seq(Seq(2)), Seq(Seq(Seq(1, 2)))) :: Nil) + } + + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { + test(s"$testName") { + withSQLConf(vectorizedReaderEnabledKey -> "true") { + withData(testThunk) + } + withSQLConf(vectorizedReaderEnabledKey -> "false") { + withData(testThunk) + } + } + } + + private def withData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeDataSourceFile(adRecords, new File(path + "/ad_records/a=1")) + + val schema = "`positions` ARRAY, `clicks`: ARRAY>>>>>" + spark.read.format(dataSourceName).schema(schema).load(path + "/ad_records") + .createOrReplaceTempView("adRecords") + + testThunk + } + } + + protected val schemaEquality = new Equality[StructType] { + override def areEqual(a: StructType, b: Any): Boolean = + b match { + case otherType: StructType => a.sameType(otherType) + case _ => false + } + } + + protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + // We check here that we can execute the query without throwing an exception. The results + // themselves are irrelevant, and should be checked elsewhere as needed + df.collect() + } + + protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + collect(df.queryExecution.executedPlan) { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } + + override protected val dataSourceName: String = "parquet" + override protected val vectorizedReaderEnabledKey: String = + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key +} From 48ca15dfba4819be6da69e87c5fdcbd23eb877db Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Mon, 23 Nov 2020 15:22:14 +0800 Subject: [PATCH 2/8] Remove duplicate unit tests --- .../test/scala/org/apache/spark/sql/ComplexTypesSuite.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index ef715483fa192..3a9cfb3a182c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -125,9 +125,4 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { :: Row(Seq(Seq(Seq(1), Seq(2)))) :: Nil) } - test("SPARK-32003: Support ExtractValue from nested ArrayStruct") { - val jsonStr1 = """{"a": [{"b": [{"c": [{"d": [1, 2]}]}]}]}""" - val df = spark.read.json(Seq(jsonStr1).toDS()) - checkAnswer(df.select($"a.b.c.d"), Row(Seq(Seq(Seq(Seq(1)))))) - } } From caf0241598ec0e7164478d71052e73f9bdb7d073 Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Mon, 23 Nov 2020 16:27:25 +0800 Subject: [PATCH 3/8] Fix code style --- .../catalyst/expressions/complexTypeExtractors.scala | 12 ++++++------ .../org/apache/spark/sql/ComplexTypesSuite.scala | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 078f60115b31e..c71decd7b4365 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -246,12 +246,12 @@ object ExtractNestedArray { } case class ExtractNestedArrayField( - child: Expression, - ordinal: Int, - numFields: Int, - field: StructField, - containsNullSeq: Seq[Boolean]) extends UnaryExpression - with ExtractValue with NullIntolerant with CodegenFallback { + child: Expression, + ordinal: Int, + numFields: Int, + field: StructField, + containsNullSeq: Seq[Boolean]) extends UnaryExpression + with ExtractValue with NullIntolerant with CodegenFallback { protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 3a9cfb3a182c0..054fe4aecb49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType} +import org.apache.spark.sql.types.{ArrayType, StructType} class ComplexTypesSuite extends QueryTest with SharedSparkSession { import testImplicits._ From 319bf38e59aaf44982b80ba5d1d64e2442f90257 Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Mon, 23 Nov 2020 17:09:46 +0800 Subject: [PATCH 4/8] Test different depths to extract nested arrays --- .../apache/spark/sql/ComplexTypesSuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 054fe4aecb49e..395517574bc51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -124,6 +124,25 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { val df = spark.read.json(Seq(jsonStr1, jsonStr2).toDS()) checkAnswer(df.select($"a.b.c"), Row(Seq(Seq(Seq(1, 2)))) :: Row(Seq(Seq(Seq(1), Seq(2)))) :: Nil) + + def genJson(start: Char, end: Char, vStr: String): String = { + (start to end).map(c => s"""{"$c": [""").mkString + + vStr + (start to end).map(_ => "]}").mkString + } + + def genResult(start: Char, end: Char, r: Seq[Int]): Any = { + (start until end).fold(r) { (z, _) => Seq(z)} + } + + val start: Char = 'a' + for (i <- 2 to 10) { + val end: Char = (start + i).toChar + val jsonAToZ = genJson(start, end, "1,2,3") + val dfAToZ = spark.read.json(Seq(jsonAToZ).toDS()) + checkAnswer(dfAToZ.select((start to end).mkString(".")), + Row(genResult(start, end, Seq(1, 2, 3)))) + } + } } From 6cfa034bbc5cdf7e3012085c1a412445f5c49596 Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Mon, 23 Nov 2020 17:13:19 +0800 Subject: [PATCH 5/8] Fixing variable naming --- .../test/scala/org/apache/spark/sql/ComplexTypesSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index 395517574bc51..85440c6dadb9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -137,9 +137,9 @@ class ComplexTypesSuite extends QueryTest with SharedSparkSession { val start: Char = 'a' for (i <- 2 to 10) { val end: Char = (start + i).toChar - val jsonAToZ = genJson(start, end, "1,2,3") - val dfAToZ = spark.read.json(Seq(jsonAToZ).toDS()) - checkAnswer(dfAToZ.select((start to end).mkString(".")), + val json = genJson(start, end, "1,2,3") + val df = spark.read.json(Seq(json).toDS()) + checkAnswer(df.select((start to end).mkString(".")), Row(genResult(start, end, Seq(1, 2, 3)))) } From 0fe667a4084e3f30c9c75e92ac40ca36720a32af Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Tue, 24 Nov 2020 10:53:48 +0800 Subject: [PATCH 6/8] ExtractNestedArray to ExtractNestedArrayType --- .../spark/sql/catalyst/expressions/ProjectionOverSchema.scala | 2 +- .../sql/catalyst/expressions/complexTypeExtractors.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 67e25fc03343b..b995bb8050bd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -53,7 +53,7 @@ case class ProjectionOverSchema(schema: StructType) { } case ExtractNestedArrayField(child, _, _, field, containsNullSeq) => getProjection(child).map(p => (p, p.dataType)).map { - case (projection, ExtractNestedArray(projSchema @ StructType(_), _, _)) => + case (projection, ExtractNestedArrayType(projSchema @ StructType(_), _, _)) => ExtractNestedArrayField(projection, projSchema.fieldIndex(field.name), projSchema.fields.length, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index c71decd7b4365..f9f98c5581483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -61,7 +61,7 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull || fields(ordinal).nullable) - case (ExtractNestedArray(StructType(fields), _, containsNullSeq), + case (ExtractNestedArrayType(StructType(fields), _, containsNullSeq), NonNullLiteral(v, StringType)) if containsNullSeq.nonEmpty => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) @@ -226,7 +226,7 @@ case class GetArrayStructFields( } } -object ExtractNestedArray { +object ExtractNestedArrayType { type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] def unapply(dataType: DataType): ReturnType = { From 523ba8db42261f275720fdc1aabee86bbc871586 Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Wed, 25 Nov 2020 00:50:49 +0800 Subject: [PATCH 7/8] Fix the problem of nullable mapping error --- .../expressions/ProjectionOverSchema.scala | 3 ++- .../sql/catalyst/expressions/SelectedField.scala | 8 ++++---- .../expressions/complexTypeExtractors.scala | 15 ++++++--------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index b995bb8050bd0..dea5926e8b2b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -51,13 +51,14 @@ case class ProjectionOverSchema(schema: StructType) { s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" ) } - case ExtractNestedArrayField(child, _, _, field, containsNullSeq) => + case ExtractNestedArrayField(child, _, _, field, containsNull, containsNullSeq) => getProjection(child).map(p => (p, p.dataType)).map { case (projection, ExtractNestedArrayType(projSchema @ StructType(_), _, _)) => ExtractNestedArrayField(projection, projSchema.fieldIndex(field.name), projSchema.fields.length, projSchema(field.name), + containsNull, containsNullSeq) case (_, projSchema) => throw new IllegalStateException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala index aa768f69e2222..3e206efef339e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala @@ -92,7 +92,7 @@ object SelectedField { } val newField = StructField(field.name, newFieldDataType, field.nullable) selectField(child, Option(ArrayType(struct(newField), containsNull))) - case ExtractNestedArrayField(child, _, _, field @ StructField(_, _, _, _), _) => + case ExtractNestedArrayField(child, _, _, field @ StructField(_, _, _, _), _, _) => val newFieldDataType = dataTypeOpt match { case None => // ExtractNestedArrayField is the top level extractor. This means its result is @@ -104,12 +104,12 @@ object SelectedField { val structType = struct(StructField(field.name, newFieldDataType, field.nullable)) val newDataType = child match { - case ExtractNestedArrayField(_, _, _, childField, _) => + case ExtractNestedArrayField(_, _, _, childField, containsNull, _) => childField.dataType match { - case _: ArrayType => ArrayType(structType) + case _: ArrayType => ArrayType(structType, containsNull) case _ => structType } - case _: GetArrayStructFields => ArrayType(structType) + case GetArrayStructFields(_, _, _, _, nullable) => ArrayType(structType, nullable) case _ => structType } selectField(child, Some(newDataType), nestArray = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f9f98c5581483..d6f22ac959476 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -61,12 +61,12 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull || fields(ordinal).nullable) - case (ExtractNestedArrayType(StructType(fields), _, containsNullSeq), + case (ExtractNestedArrayType(StructType(fields), containsNull, containsNullSeq), NonNullLiteral(v, StringType)) if containsNullSeq.nonEmpty => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) ExtractNestedArrayField(child, ordinal, fields.length, - fields(ordinal).copy(name = fieldName), containsNullSeq) + fields(ordinal).copy(name = fieldName), containsNull, containsNullSeq) case (_: ArrayType, _) => GetArrayItem(child, extraction) @@ -230,14 +230,10 @@ object ExtractNestedArrayType { type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] def unapply(dataType: DataType): ReturnType = { - extractArrayType(dataType) - } - - def extractArrayType(dataType: DataType): ReturnType = { dataType match { case ArrayType(dt, containsNull) => - extractArrayType(dt) match { - case Some((d, cn, seq)) => Some((d, cn, containsNull +: seq)) + unapply(dt) match { + case Some((d, cn, seq)) => Some((d, containsNull, cn +: seq)) case None => Some((dt, containsNull, Seq.empty[Boolean])) } case _ => None @@ -250,6 +246,7 @@ case class ExtractNestedArrayField( ordinal: Int, numFields: Int, field: StructField, + containsNull: Boolean, containsNullSeq: Seq[Boolean]) extends UnaryExpression with ExtractValue with NullIntolerant with CodegenFallback { @@ -289,7 +286,7 @@ case class ExtractNestedArrayField( } } - override def dataType: DataType = ArrayType(nestedArrayType(0)) + override def dataType: DataType = ArrayType(nestedArrayType(0), containsNull) def nestedArrayType(num: Int): DataType = { (num until containsNullSeq.size).reverse From b18c03acb0fb4402f0b56f785fd201034426e27d Mon Sep 17 00:00:00 2001 From: acetylzhang Date: Wed, 25 Nov 2020 00:51:47 +0800 Subject: [PATCH 8/8] Add document --- .../expressions/complexTypeExtractors.scala | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index d6f22ac959476..6bdb1f05d2895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -37,12 +37,13 @@ object ExtractValue { * Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`, * depend on the type of `child` and `extraction`. * - * `child` | `extraction` | concrete `ExtractValue` - * ---------------------------------------------------------------- - * Struct | Literal String | GetStructField - * Array[Struct] | Literal String | GetArrayStructFields - * Array | Integral type | GetArrayItem - * Map | map key type | GetMapValue + * `child` | `extraction` | concrete `ExtractValue` + * -------------------------------------------------------------------------------- + * Struct | Literal String | GetStructField + * Array[Struct] | Literal String | GetArrayStructFields + * Array[ ...Array[struct] ] | Literal String | ExtractNestedArrayField + * Array | Integral type | GetArrayItem + * Map | map key type | GetMapValue */ def apply( child: Expression, @@ -226,6 +227,13 @@ case class GetArrayStructFields( } } +/** + * ExtractNestedArrayType is used to match consecutive nested array types. + * + * ReturnType: (DataType: the innermost dataType, Boolean: the outermost array contains null + * , Seq[Boolean]: the second outer layer to the innermost layer contains null) + * + */ object ExtractNestedArrayType { type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] @@ -241,6 +249,10 @@ object ExtractNestedArrayType { } } +/** + * For a child whose data type is a nested array containing struct at the innermost level, extracts + * the `ordinal`-th fields of multi-level nested array, and returns them as a new nested array. + */ case class ExtractNestedArrayField( child: Expression, ordinal: Int,