From 25a4d19ef62ec7dadd397110839ce9a98e246a64 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 May 2016 16:38:42 -0700 Subject: [PATCH 1/6] encoder code cleanup --- .../linalg/UDTSerializationBenchmark.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 53 +++++++- .../catalyst/encoders/ExpressionEncoder.scala | 126 ++++-------------- .../encoders/EncoderResolutionSuite.scala | 42 +++--- .../encoders/ExpressionEncoderSuite.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 47 ++++--- .../spark/sql/KeyValueGroupedDataset.scala | 26 ++-- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 8 +- .../org/apache/spark/sql/QueryTest.scala | 8 +- 11 files changed, 134 insertions(+), 184 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index be7110ad6bbf..8b439e6b7a01 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -29,7 +29,7 @@ object UDTSerializationBenchmark { val iters = 1e2.toInt val numRows = 1e3.toInt - val encoder = ExpressionEncoder[Vector].defaultBinding + val encoder = ExpressionEncoder[Vector].resolveAndBind() val vectors = (1 to numRows).map { i => Vectors.dense(Array.fill(1e5.toInt)(1.0 * i)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02966796afdd..9b3df69c231e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1924,10 +1924,61 @@ class Analyzer( } else { inputAttributes } + + validateBoundReference(deserializer, inputs) + val unbound = deserializer transform { case b: BoundReference => inputs(b.ordinal) } - resolveExpression(unbound, LocalRelation(inputs), throws = true) + val resolved = resolveExpression(unbound, LocalRelation(inputs), throws = true) + + validateGetStructField(resolved) + resolved + } + } + + private def fail(schema: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.") + } + + /** + * Deserializer expression may contains `BoundReference`s instead of `AttributeReference`s + * (e.g. if it's from a tuple encoder or tupled encoder), we should check if their + * ordinals match the real schema. + */ + private def validateBoundReference(deserializer: Expression, inputs: Seq[Attribute]): Unit = { + var maxOrdinal = -1 + deserializer.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != inputs.length - 1) { + fail(inputs.toStructType, maxOrdinal) + } + } + + /** + * Deserializer expression may contains `GetStructField`s instead of `UnresolvedExtractValue`s + * (e.g. if we are encoding a nested tuple), we should check if their ordinals match the real + * schema. + */ + private def validateGetStructField(deserializer: Expression): Unit = { + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + deserializer foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 2296946cd7c5..4358e19380da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.catalyst.encoders -import java.util.concurrent.ConcurrentMap - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} import org.apache.spark.sql.types.{ObjectType, StructField, StructType} import org.apache.spark.util.Utils @@ -192,6 +190,26 @@ case class ExpressionEncoder[T]( if (flat) require(serializer.size == 1) + /** + * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the + * given schema. + * + * Note that, ideally encoder is used as a container of serde expressions, the resolution and + * binding stuff should happen inside query framework. However, in some cases we need to + * use encoder as a function to do serialization directly(e.g. Dataset.collect), then we can use + * this method to do resolution and binding outside of query framework. + */ + def resolveAndBind( + attrs: Seq[Attribute] = schema.toAttributes, + analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = { + val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this) + val analyzedPlan = analyzer.execute(dummyPlan) + analyzer.checkAnalysis(analyzedPlan) + val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer + val bound = BindReferences.bindReference(resolved, attrs) + copy(deserializer = bound) + } + @transient private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @@ -201,16 +219,6 @@ case class ExpressionEncoder[T]( @transient private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) - /** - * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns - * is performed). - */ - def defaultBinding: ExpressionEncoder[T] = { - val attrs = schema.toAttributes - resolve(attrs, OuterScopes.outerScopes).bind(attrs) - } - - /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. @@ -236,7 +244,7 @@ case class ExpressionEncoder[T]( /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * you must `resolveAndBind` an encoder to a specific schema before you can call this * function. */ def fromRow(row: InternalRow): T = try { @@ -259,94 +267,6 @@ case class ExpressionEncoder[T]( }) } - /** - * Validates `deserializer` to make sure it can be resolved by given schema, and produce - * friendly error messages to explain why it fails to resolve if there is something wrong. - */ - def validate(schema: Seq[Attribute]): Unit = { - def fail(st: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + - " - Target schema: " + this.schema.simpleString) - } - - // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all - // `BoundReference`, make sure their ordinals are all valid. - var maxOrdinal = -1 - deserializer.foreach { - case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal - case _ => - } - if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { - fail(StructType.fromAttributes(schema), maxOrdinal) - } - - // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of - // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. - // Note that, `BoundReference` contains the expected type, but here we need the actual type, so - // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after - // we resolve the `fromRowExpression`. - val resolved = SimpleAnalyzer.resolveExpression( - deserializer, - LocalRelation(schema), - throws = true) - - val unbound = resolved transform { - case b: BoundReference => schema(b.ordinal) - } - - val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] - unbound.foreach { - case g: GetStructField => - val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) - if (maxOrdinal < g.ordinal) { - exprToMaxOrdinal.update(g.child, g.ordinal) - } - case _ => - } - exprToMaxOrdinal.foreach { - case (expr, maxOrdinal) => - val schema = expr.dataType.asInstanceOf[StructType] - if (maxOrdinal != schema.length - 1) { - fail(schema, maxOrdinal) - } - } - } - - /** - * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. - */ - def resolve( - schema: Seq[Attribute], - outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check - // analysis, go through optimizer, etc. - val plan = Project( - Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, - LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) - } - - /** - * Returns a copy of this encoder where the `deserializer` has been bound to the - * ordinals of the given schema. Note that you need to first call resolve before bind. - */ - def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(deserializer = BindReferences.bindReference(deserializer, schema)) - } - - /** - * Returns a new encoder with input columns shifted by `delta` ordinals - */ - def shift(delta: Int): ExpressionEncoder[T] = { - copy(deserializer = deserializer transform { - case r: BoundReference => r.copy(ordinal = r.ordinal + delta) - }) - } - protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 3ad0dae767be..7251202c7bd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest { // int type can be up cast to long type val attrs1 = Seq('a.string, 'b.int) - encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) + encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1)) // int type can be up cast to string type val attrs2 = Seq('a.int, 'b.long) - encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) + encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { @@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } test("nullability of array type element should not fail analysis") { @@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = 'a.array(IntegerType) :: Nil // It should pass analysis - val bound = encoder.resolve(attrs, null).bind(attrs) + val bound = encoder.resolveAndBind(attrs) // If no null values appear, it should works fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) @@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.long, 'c.int) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed as the number of fields does not line up.") } { val attrs = Seq('a.string) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed as the number of fields does not line up.") } } @@ -106,26 +102,22 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } { val attrs = Seq('a.string, 'b.struct('x.long)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } } test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { - ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long)) }.message assert(msg1 == s""" @@ -138,7 +130,7 @@ class EncoderResolutionSuite extends PlanTest { val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) - ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType))) }.message assert(msg2 == s""" @@ -171,7 +163,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { - to.resolve(from.schema.toAttributes, null) + to.resolveAndBind(from.schema.toAttributes) } } @@ -180,7 +172,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { - intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 232dcc9ee51c..282d5cdf01a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -334,7 +334,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.defaultBinding + val boundEncoder = encoder.resolveAndBind() val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 369b772d322c..a52b1754f9ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -192,24 +192,23 @@ class Dataset[T] private[sql]( } /** - * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is - * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the - * same object type (that will be possibly resolved to a different schema). + * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the + * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use + * it when constructing new [[Dataset]] objects that have the same object type (that will be + * possibly resolved to a different schema). */ - private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) - unresolvedTEncoder.validate(logicalPlan.output) - - /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + private[sql] implicit val enc: ExpressionEncoder[T] = encoderFor(encoder) /** - * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of this [[Dataset]]'s output schema. + * Encoder is used mostly as a container of serde expressions in [[Dataset]]. We build logical + * plans by these serde expressions and execute it within the query framework. However, for + * performance reasons we may want to use encoder as a function to deserialize internal rows to + * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its + * `fromRow` method later. */ - private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + private val boundEnc = enc.resolveAndBind(logicalPlan.output) - private implicit def classTag = unresolvedTEncoder.clsTag + private implicit def classTag = enc.clsTag // sqlContext must be val because a stable identifier is expected when you import implicits @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext @@ -761,7 +760,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (this.unresolvedTEncoder.flat) { + val combined = if (this.enc.flat) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -771,7 +770,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (other.unresolvedTEncoder.flat) { + val combined = if (other.enc.flat) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -784,14 +783,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (this.unresolvedTEncoder.flat) { + if (this.enc.flat) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (other.unresolvedTEncoder.flat) { + if (other.enc.flat) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -800,7 +799,7 @@ class Dataset[T] private[sql]( } implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) + ExpressionEncoder.tuple(this.enc, other.enc) withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) } @@ -1024,7 +1023,7 @@ class Dataset[T] private[sql]( sparkSession, Project( c1.withInputType( - unresolvedTEncoder.deserializer, + enc.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1038,7 +1037,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) + columns.map(_.withInputType(enc.deserializer, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -2153,14 +2152,14 @@ class Dataset[T] private[sql]( */ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { - val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) java.util.Arrays.asList(values : _*) } } private def collect(needCallback: Boolean): Array[T] = { def execute(): Array[T] = withNewExecutionId { - queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) } if (needCallback) { @@ -2184,7 +2183,7 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ => withNewExecutionId { - queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava + queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava } } @@ -2322,7 +2321,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ lazy val rdd: RDD[T] = { - val objectType = unresolvedTEncoder.deserializer.dataType + val objectType = enc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 53f4ea647c85..7e0591f6e608 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -42,17 +42,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { - // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders - // when constructing new logical plans that will operate on the output of the current - // queryexecution. - - private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedVEncoder = encoderFor(vEncoder) - - private val resolvedKEncoder = - unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedVEncoder = - unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly. + private implicit val kEnc = encoderFor(kEncoder) + private implicit val vEnc = encoderFor(vEncoder) private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession @@ -67,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = new KeyValueGroupedDataset( encoderFor[L], - unresolvedVEncoder, + vEnc, queryExecution, dataAttributes, groupingAttributes) @@ -187,7 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + implicit val resultEncoder = ExpressionEncoder.tuple(kEnc, vEnc) flatMapGroups(func) } @@ -209,8 +201,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) - val keyColumn = if (resolvedKEncoder.flat) { + columns.map(_.withInputType(vEnc.deserializer, dataAttributes).named) + val keyColumn = if (kEnc.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { @@ -222,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( new Dataset( sparkSession, execution, - ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) + ExpressionEncoder.tuple(kEnc +: encoders)) } /** @@ -287,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.unresolvedVEncoder + implicit val uEncoder = other.vEnc Dataset[R]( sparkSession, CoGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 58850a7d4b49..26ce28f09b43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -215,7 +215,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + typed.withInputType(df.enc.deserializer, df.logicalPlan.output).expr case c => c.expr }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d89e98645b36..e319431bd1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -924,7 +924,7 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.unresolvedTEncoder) + Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.enc) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index df8f4b0610af..d1c232974e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -566,18 +566,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { }.message assert(message == "Try to map struct to Tuple3, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:int,_3:bigint>") + "but failed as the number of fields does not line up.") val message2 = intercept[AnalysisException] { ds.as[Tuple1[String]] }.message assert(message2 == "Try to map struct to Tuple1, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string>") + "but failed as the number of fields does not line up.") } test("SPARK-13440: Resolving option fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a1a9b66c1fee..20beaa7e9cb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + spark.createDataset(expectedAnswer)(ds.enc).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -94,8 +94,8 @@ abstract class QueryTest extends PlanTest { fail( s""" |Exception collecting dataset as objects - |${ds.resolvedTEncoder} - |${ds.resolvedTEncoder.deserializer.treeString} + |${ds.enc} + |${ds.enc.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } @@ -114,7 +114,7 @@ abstract class QueryTest extends PlanTest { fail( s"""Decoded objects do not match expected objects: |$comparison - |${ds.resolvedTEncoder.deserializer.treeString} + |${ds.enc.deserializer.treeString} """.stripMargin) } } From 2920e38c5b4c125c4cb31516cf081b2ca9dbbd79 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 May 2016 17:19:02 -0700 Subject: [PATCH 2/6] use GetColumnByOrdinal instead of BoundReference --- .../scala/org/apache/spark/sql/Encoders.scala | 3 +- .../sql/catalyst/JavaTypeInference.scala | 6 +- .../spark/sql/catalyst/ScalaReflection.scala | 307 +++++++++--------- .../sql/catalyst/analysis/Analyzer.scala | 35 +- .../sql/catalyst/analysis/unresolved.scala | 7 + .../catalyst/encoders/ExpressionEncoder.scala | 10 +- .../sql/catalyst/encoders/RowEncoder.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 19 +- .../encoders/ExpressionEncoderSuite.scala | 9 +- .../catalyst/encoders/RowEncoderSuite.scala | 14 +- .../scala/org/apache/spark/sql/Dataset.scala | 22 +- .../spark/sql/KeyValueGroupedDataset.scala | 16 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 6 +- .../org/apache/spark/sql/functions.scala | 2 +- .../org/apache/spark/sql/QueryTest.scala | 8 +- .../sql/execution/GroupedIteratorSuite.scala | 6 +- .../spark/sql/streaming/StreamTest.scala | 4 +- 19 files changed, 237 insertions(+), 249 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index fa96f8223d17..673c587b1832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -23,6 +23,7 @@ import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.catalyst.expressions.BoundReference @@ -208,7 +209,7 @@ object Encoders { BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), deserializer = DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + GetColumnByOrdinal(0, BinaryType), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1fe143494aba..b3a233ae390a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -25,7 +25,7 @@ import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -177,8 +177,8 @@ object JavaTypeInference { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + /** Returns the current path or `GetColumnByOrdinal`. */ + def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 47508618178e..78c145d4fd93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -156,17 +156,17 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + .getOrElse(GetColumnByOrdinal(ordinal, dataType)) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `BoundReference`. */ + /** Returns the current path or `GetColumnByOrdinal`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType if (path.isDefined) { path.get } else { - upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) } } @@ -421,7 +421,7 @@ object ScalaReflection extends ScalaReflection { def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) @@ -449,157 +449,156 @@ object ScalaReflection extends ScalaReflection { } } - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath - val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) - - // Since List[_] also belongs to localTypeOf[Product], we put this case before - // "case t if definedByConstructorParams(t)" to make sure it will match to the - // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigInteger] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[scala.math.BigInt] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt, inputObject :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt, inputObject :: Nil) - - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => - if (javaKeywords.contains(fieldName)) { - throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + - "cannot be used as field name\n" + walkedTypePath.mkString("\n")) - } + tpe match { + case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - val clsName = getClassNameFromType(fieldType) - val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil - }) - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) - - case other => - throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) - } + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath) + + // Since List[_] also belongs to localTypeOf[Product], we put this case before + // "case t if definedByConstructorParams(t)" to make sure it will match to the + // case "localTypeOf[Seq[_]]" + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigInteger] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[scala.math.BigInt] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if UDTRegistration.exists(getClassNameFromType(t)) => + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() + .asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass)) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + if (javaKeywords.contains(fieldName)) { + throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + + "cannot be used as field name\n" + walkedTypePath.mkString("\n")) + } + + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + + case other => + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9b3df69c231e..a776772da5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1925,15 +1925,14 @@ class Analyzer( inputAttributes } - validateBoundReference(deserializer, inputs) - - val unbound = deserializer transform { - case b: BoundReference => inputs(b.ordinal) + validateTupleColumns(deserializer, inputs) + val ordinalResolved = deserializer transform { + case GetColumnByOrdinal(ordinal, _) => inputs(ordinal) } - val resolved = resolveExpression(unbound, LocalRelation(inputs), throws = true) - - validateGetStructField(resolved) - resolved + val attrResolved = resolveExpression( + ordinalResolved, LocalRelation(inputs), throws = true) + validateInnerTupleFields(attrResolved) + attrResolved } } @@ -1943,14 +1942,15 @@ class Analyzer( } /** - * Deserializer expression may contains `BoundReference`s instead of `AttributeReference`s - * (e.g. if it's from a tuple encoder or tupled encoder), we should check if their - * ordinals match the real schema. + * For each Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column by + * position. However, the actual number of columns may be different from the number of Tuple + * fields. This method is used to check the number of columns and fields, and throw an + * exception if they do not match. */ - private def validateBoundReference(deserializer: Expression, inputs: Seq[Attribute]): Unit = { + private def validateTupleColumns(deserializer: Expression, inputs: Seq[Attribute]): Unit = { var maxOrdinal = -1 deserializer.foreach { - case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case GetColumnByOrdinal(ordinal, _) => if (ordinal > maxOrdinal) maxOrdinal = ordinal case _ => } if (maxOrdinal >= 0 && maxOrdinal != inputs.length - 1) { @@ -1959,11 +1959,12 @@ class Analyzer( } /** - * Deserializer expression may contains `GetStructField`s instead of `UnresolvedExtractValue`s - * (e.g. if we are encoding a nested tuple), we should check if their ordinals match the real - * schema. + * For each inner Tuple field, we use [[GetStructField]] to get its corresponding struct field + * by position. However, the actual number of struct fields may be different from the number + * of inner Tuple fields. This method is used to check the number of struct fields and inner + * Tuple fields, and throw an exception if they do not match. */ - private def validateGetStructField(deserializer: Expression): Unit = { + private def validateInnerTupleFields(deserializer: Expression): Unit = { val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] deserializer foreach { case g: GetStructField => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index e953eda7843c..b883546135f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -366,3 +366,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression + with Unevaluable with NonSQLExpression { + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 4358e19380da..cc59d06fa351 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} @@ -119,15 +119,15 @@ object ExpressionEncoder { val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { enc.deserializer.transform { - case b: BoundReference => b.copy(ordinal = index) + case g: GetColumnByOrdinal => g.copy(ordinal = index) } } else { - val input = BoundReference(index, enc.schema, nullable = true) + val input = GetColumnByOrdinal(index, enc.schema) val deserialized = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) + case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) } If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0de9166aa29a..3c6ae1c5ccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -210,12 +211,7 @@ object RowEncoder { case p: PythonUserDefinedType => p.sqlType case other => other } - val field = BoundReference(i, dt, f.nullable) - If( - IsNull(field), - Literal.create(null, externalDataTypeFor(dt)), - deserializerFor(field) - ) + deserializerFor(GetColumnByOrdinal(i, dt)) } CreateExternalRow(fields, schema) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2ec46216e1cd..529717772b52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -190,7 +190,7 @@ abstract class Expression extends TreeNode[Expression] { case single => single :: Nil } - override def simpleString: String = toString + // override def simpleString: String = toString override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 98ce5dd2efd9..55d8adf0408f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,26 +30,13 @@ object CatalystSerde { DeserializeToObject(deserializer, generateObjAttr[T], child) } - def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = { - val deserializer = UnresolvedDeserializer(encoder.deserializer) - DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child) - } - def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } - def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = { - SerializeFromObject(encoder.namedExpressions, child) - } - def generateObjAttr[T : Encoder]: Attribute = { AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() } - - def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = { - AttributeReference("obj", encoder.deserializer.dataType, nullable = false)() - } } /** @@ -128,16 +115,16 @@ object MapPartitionsInR { schema: StructType, encoder: ExpressionEncoder[Row], child: LogicalPlan): LogicalPlan = { - val deserialized = CatalystSerde.deserialize(child, encoder) + val deserialized = CatalystSerde.deserialize(child)(encoder) val mapped = MapPartitionsInR( func, packageNames, broadcastVars, encoder.schema, schema, - CatalystSerde.generateObjAttrForRow(RowEncoder(schema)), + CatalystSerde.generateObjAttr(RowEncoder(schema)), deserialized) - CatalystSerde.serialize(mapped, RowEncoder(schema)) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 282d5cdf01a7..a1f9259f139e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} @@ -350,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } // Test the correct resolution of serialization / deserialization. - val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() - val inputPlan = LocalRelation(attr) - val plan = - Project(Alias(encoder.deserializer, "obj")() :: Nil, - Project(encoder.namedExpressions, - inputPlan)) + val attr = AttributeReference("obj", encoder.deserializer.dataType)() + val plan = LocalRelation(attr).serialize[T].deserialize[T] assertAnalysisSuccess(plan) val isCorrect = (input, convertedBack) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 39fcc7225b37..6f1bc80c1cdd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -135,7 +135,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("double", DoubleType)) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input: Row = Row((100, "test", 0.123)) val row = encoder.toRow(input) @@ -152,7 +152,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("scala_decimal", DecimalType.SYSTEM_DEFAULT) .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val javaDecimal = new java.math.BigDecimal("1234.5678") val scalaDecimal = BigDecimal("1234.5678") @@ -169,7 +169,7 @@ class RowEncoderSuite extends SparkFunSuite { test("RowEncoder should preserve decimal precision and scale") { val schema = new StructType().add("decimal", DecimalType(10, 5), false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val decimal = Decimal("67123.45") val input = Row(decimal) val row = encoder.toRow(input) @@ -179,7 +179,7 @@ class RowEncoderSuite extends SparkFunSuite { test("RowEncoder should preserve schema nullability") { val schema = new StructType().add("int", IntegerType, nullable = false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() assert(encoder.serializer.length == 1) assert(encoder.serializer.head.dataType == IntegerType) assert(encoder.serializer.head.nullable == false) @@ -195,7 +195,7 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("int", IntegerType, nullable = false), nullable = false), nullable = false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() assert(encoder.serializer.length == 1) assert(encoder.serializer.head.dataType == new StructType() @@ -212,7 +212,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("array", ArrayType(IntegerType)) .add("nestedArray", ArrayType(ArrayType(StringType))) .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType)))) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Row( Array(1, 2, null), Array(Array("abc", null), null), @@ -226,7 +226,7 @@ class RowEncoderSuite extends SparkFunSuite { private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get var input: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a52b1754f9ee..e90b1783f937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -197,7 +197,7 @@ class Dataset[T] private[sql]( * it when constructing new [[Dataset]] objects that have the same object type (that will be * possibly resolved to a different schema). */ - private[sql] implicit val enc: ExpressionEncoder[T] = encoderFor(encoder) + private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) /** * Encoder is used mostly as a container of serde expressions in [[Dataset]]. We build logical @@ -206,9 +206,9 @@ class Dataset[T] private[sql]( * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its * `fromRow` method later. */ - private val boundEnc = enc.resolveAndBind(logicalPlan.output) + private val boundEnc = exprEnc.resolveAndBind(logicalPlan.output) - private implicit def classTag = enc.clsTag + private implicit def classTag = exprEnc.clsTag // sqlContext must be val because a stable identifier is expected when you import implicits @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext @@ -760,7 +760,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (this.enc.flat) { + val combined = if (this.exprEnc.flat) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -770,7 +770,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (other.enc.flat) { + val combined = if (other.exprEnc.flat) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -783,14 +783,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (this.enc.flat) { + if (this.exprEnc.flat) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (other.enc.flat) { + if (other.exprEnc.flat) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -799,7 +799,7 @@ class Dataset[T] private[sql]( } implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.enc, other.enc) + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) } @@ -1023,7 +1023,7 @@ class Dataset[T] private[sql]( sparkSession, Project( c1.withInputType( - enc.deserializer, + exprEnc.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1037,7 +1037,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(enc.deserializer, logicalPlan.output).named) + columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -2321,7 +2321,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ lazy val rdd: RDD[T] = { - val objectType = enc.deserializer.dataType + val objectType = exprEnc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 7e0591f6e608..a6867a67eead 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -43,8 +43,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( private val groupingAttributes: Seq[Attribute]) extends Serializable { // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly. - private implicit val kEnc = encoderFor(kEncoder) - private implicit val vEnc = encoderFor(vEncoder) + private implicit val kExprEnc = encoderFor(kEncoder) + private implicit val vExprEnc = encoderFor(vEncoder) private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession @@ -59,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = new KeyValueGroupedDataset( encoderFor[L], - vEnc, + vExprEnc, queryExecution, dataAttributes, groupingAttributes) @@ -179,7 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = ExpressionEncoder.tuple(kEnc, vEnc) + implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc) flatMapGroups(func) } @@ -201,8 +201,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(vEnc.deserializer, dataAttributes).named) - val keyColumn = if (kEnc.flat) { + columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named) + val keyColumn = if (kExprEnc.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { @@ -214,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( new Dataset( sparkSession, execution, - ExpressionEncoder.tuple(kEnc +: encoders)) + ExpressionEncoder.tuple(kExprEnc +: encoders)) } /** @@ -279,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.vEnc + implicit val uEncoder = other.vExprEnc Dataset[R]( sparkSession, CoGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 26ce28f09b43..49b6eab8db5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -215,7 +215,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.enc.deserializer, df.logicalPlan.output).expr + typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr case c => c.expr }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 8f94184764c0..ecb56e2a2848 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -33,9 +33,9 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] val bufferSerializer = bufferEncoder.namedExpressions - val bufferDeserializer = bufferEncoder.deserializer.transform { - case b: BoundReference => bufferSerializer(b.ordinal).toAttribute - } + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e319431bd1b2..4dbd1665e4bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -924,7 +924,7 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.enc) + Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 20beaa7e9cb9..9c044f4e8fd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - spark.createDataset(expectedAnswer)(ds.enc).toDF().collect().toSeq) + spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -94,8 +94,8 @@ abstract class QueryTest extends PlanTest { fail( s""" |Exception collecting dataset as objects - |${ds.enc} - |${ds.enc.deserializer.treeString} + |${ds.exprEnc} + |${ds.exprEnc.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } @@ -114,7 +114,7 @@ abstract class QueryTest extends PlanTest { fail( s"""Decoded objects do not match expected objects: |$comparison - |${ds.enc.deserializer.treeString} + |${ds.exprEnc.deserializer.treeString} """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 6f10e4b80577..80340b5552c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -27,7 +27,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("basic") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) @@ -45,7 +45,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("group by 2 columns") { val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq( Row(1, 2L, "a"), @@ -72,7 +72,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("do nothing to the value iterator") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index dd8672aa641d..194c3e730725 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -110,7 +110,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { object CheckAnswer { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] - val toExternalRow = RowEncoder(encoder.schema) + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) } @@ -124,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { object CheckLastBatch { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] - val toExternalRow = RowEncoder(encoder.schema) + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) } From c923486603b3063aa893d1bc4a210d00b9ce390e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jun 2016 13:15:57 -0700 Subject: [PATCH 3/6] update --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 529717772b52..2ec46216e1cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -190,7 +190,7 @@ abstract class Expression extends TreeNode[Expression] { case single => single :: Nil } - // override def simpleString: String = toString + override def simpleString: String = toString override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e90b1783f937..96c871d03435 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -206,7 +206,8 @@ class Dataset[T] private[sql]( * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its * `fromRow` method later. */ - private val boundEnc = exprEnc.resolveAndBind(logicalPlan.output) + private val boundEnc = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) private implicit def classTag = exprEnc.clsTag From 1377b5edcc3f26748c699c122f59e1858d251c7f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Jun 2016 11:00:11 -0700 Subject: [PATCH 4/6] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a776772da5b3..0bec44f57c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1948,13 +1948,12 @@ class Analyzer( * exception if they do not match. */ private def validateTupleColumns(deserializer: Expression, inputs: Seq[Attribute]): Unit = { - var maxOrdinal = -1 - deserializer.foreach { - case GetColumnByOrdinal(ordinal, _) => if (ordinal > maxOrdinal) maxOrdinal = ordinal - case _ => - } - if (maxOrdinal >= 0 && maxOrdinal != inputs.length - 1) { - fail(inputs.toStructType, maxOrdinal) + val ordinals = deserializer.collect { + case GetColumnByOrdinal(ordinal, _) => ordinal + }.distinct.sorted + + if (ordinals.nonEmpty && ordinals != inputs.indices) { + fail(inputs.toStructType, ordinals.last) } } @@ -1965,20 +1964,22 @@ class Analyzer( * Tuple fields, and throw an exception if they do not match. */ private def validateInnerTupleFields(deserializer: Expression): Unit = { - val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + val exprToOrdinals = scala.collection.mutable.HashMap.empty[Expression, ArrayBuffer[Int]] deserializer foreach { case g: GetStructField => - val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) - if (maxOrdinal < g.ordinal) { - exprToMaxOrdinal.update(g.child, g.ordinal) + if (exprToOrdinals.contains(g.child)) { + exprToOrdinals(g.child) += g.ordinal + } else { + exprToOrdinals += g.child -> ArrayBuffer(g.ordinal) } case _ => } - exprToMaxOrdinal.foreach { - case (expr, maxOrdinal) => + exprToOrdinals.foreach { + case (expr, ordinals) => val schema = expr.dataType.asInstanceOf[StructType] - if (maxOrdinal != schema.length - 1) { - fail(schema, maxOrdinal) + val sortedOrdinals: Seq[Int] = ordinals.distinct.sorted + if (sortedOrdinals.nonEmpty && sortedOrdinals != schema.indices) { + fail(schema, sortedOrdinals.last) } } } From 52531c397a52a92acb6ab1ea62dd0ec37d8d6e3c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Jun 2016 17:17:25 -0700 Subject: [PATCH 5/6] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 53 ++++++++----------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0bec44f57c0f..497e8b760fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -722,6 +722,7 @@ class Analyzer( // Else, throw exception. try { expr transformUp { + case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } case UnresolvedExtractValue(child, fieldName) if child.resolved => @@ -1925,14 +1926,11 @@ class Analyzer( inputAttributes } - validateTupleColumns(deserializer, inputs) - val ordinalResolved = deserializer transform { - case GetColumnByOrdinal(ordinal, _) => inputs(ordinal) - } - val attrResolved = resolveExpression( - ordinalResolved, LocalRelation(inputs), throws = true) - validateInnerTupleFields(attrResolved) - attrResolved + validateTopLevelTupleFields(deserializer, inputs) + val resolved = resolveExpression( + deserializer, LocalRelation(inputs), throws = true) + validateNestedTupleFields(resolved) + resolved } } @@ -1942,12 +1940,12 @@ class Analyzer( } /** - * For each Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column by - * position. However, the actual number of columns may be different from the number of Tuple + * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column + * by position. However, the actual number of columns may be different from the number of Tuple * fields. This method is used to check the number of columns and fields, and throw an * exception if they do not match. */ - private def validateTupleColumns(deserializer: Expression, inputs: Seq[Attribute]): Unit = { + private def validateTopLevelTupleFields(deserializer: Expression, inputs: Seq[Attribute]): Unit = { val ordinals = deserializer.collect { case GetColumnByOrdinal(ordinal, _) => ordinal }.distinct.sorted @@ -1958,29 +1956,22 @@ class Analyzer( } /** - * For each inner Tuple field, we use [[GetStructField]] to get its corresponding struct field + * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field * by position. However, the actual number of struct fields may be different from the number - * of inner Tuple fields. This method is used to check the number of struct fields and inner + * of nested Tuple fields. This method is used to check the number of struct fields and nested * Tuple fields, and throw an exception if they do not match. */ - private def validateInnerTupleFields(deserializer: Expression): Unit = { - val exprToOrdinals = scala.collection.mutable.HashMap.empty[Expression, ArrayBuffer[Int]] - deserializer foreach { - case g: GetStructField => - if (exprToOrdinals.contains(g.child)) { - exprToOrdinals(g.child) += g.ordinal - } else { - exprToOrdinals += g.child -> ArrayBuffer(g.ordinal) - } - case _ => - } - exprToOrdinals.foreach { - case (expr, ordinals) => - val schema = expr.dataType.asInstanceOf[StructType] - val sortedOrdinals: Seq[Int] = ordinals.distinct.sorted - if (sortedOrdinals.nonEmpty && sortedOrdinals != schema.indices) { - fail(schema, sortedOrdinals.last) - } + private def validateNestedTupleFields(deserializer: Expression): Unit = { + val structChildToOrdinals = deserializer + .collect { case g: GetStructField => g } + .groupBy(_.child) + .mapValues(_.map(_.ordinal).distinct.sorted) + + structChildToOrdinals.foreach { case (expr, ordinals) => + val schema = expr.dataType.asInstanceOf[StructType] + if (ordinals != schema.indices) { + fail(schema, ordinals.last) + } } } } From efe0cd57056673fcd3406ffaa468b431fa71eac9 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 2 Jun 2016 17:28:50 -0700 Subject: [PATCH 6/6] Fixes ScalaStyle check failure --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 497e8b760fac..4f6b4830cd6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1945,7 +1945,8 @@ class Analyzer( * fields. This method is used to check the number of columns and fields, and throw an * exception if they do not match. */ - private def validateTopLevelTupleFields(deserializer: Expression, inputs: Seq[Attribute]): Unit = { + private def validateTopLevelTupleFields( + deserializer: Expression, inputs: Seq[Attribute]): Unit = { val ordinals = deserializer.collect { case GetColumnByOrdinal(ordinal, _) => ordinal }.distinct.sorted