From 358e5bf802385adcab86d3bca3ccdf3fe4053648 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Sun, 4 Aug 2019 14:00:21 -0700 Subject: [PATCH 01/10] SPARK-28612: Add DataFrameWriterV2 API. --- .../expressions/PartitionTransforms.scala | 77 +++ .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../plans/logical/basicLogicalOperators.scala | 47 +- .../v2/DataSourceV2Implicits.scala | 9 + .../spark/sql/connector/InMemoryTable.scala | 5 +- .../apache/spark/sql/DataFrameWriter.scala | 11 +- .../apache/spark/sql/DataFrameWriterV2.scala | 419 +++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 28 + .../datasources/v2/DataSourceV2Strategy.scala | 20 +- .../datasources/v2/V2WriteSupportCheck.scala | 6 +- .../org/apache/spark/sql/functions.scala | 55 ++ .../sources/v2/DataFrameWriterV2Suite.scala | 578 ++++++++++++++++++ 12 files changed, 1225 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala new file mode 100644 index 0000000000000..ab18c4a7cb1b3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types.{DataType, IntegerType} + +/** + * Base class for expressions that are converted to v2 partition transforms. + * + * Subclasses represent abstract transform functions with concrete implementations that are + * determined by data source implementations. Because the concrete implementation is not known, + * these expressions are [[Unevaluable]]. + * + * These expressions are used to pass transformations from the DataFrame API: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create() + * }}} + */ +abstract class PartitionTransformExpression extends Expression with Unevaluable { + override def nullable: Boolean = true +} + +/** + * Expression for the v2 partition transform years. + */ +case class Years(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform months. + */ +case class Months(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform days. + */ +case class Days(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform hours. 
+ */ +case class Hours(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform bucket. + */ +case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(numBuckets, child) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcb6af6829c3f..0cb59411bd95a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2506,7 +2506,7 @@ class Analyzer( */ object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case append @ AppendData(table, query, isByName) + case append @ AppendData(table, query, _, isByName) if table.resolved && query.resolved && !append.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( @@ -2518,7 +2518,7 @@ class Analyzer( append } - case overwrite @ OverwriteByExpression(table, _, query, isByName) + case overwrite @ OverwriteByExpression(table, _, query, _, isByName) if table.resolved && query.resolved && !overwrite.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( @@ -2530,7 +2530,7 @@ class Analyzer( overwrite } - case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName) if table.resolved && query.resolved && !overwrite.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 0be61cf147041..6e1825e4997c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -489,7 +489,7 @@ case class ReplaceTableAsSelect( override def tableSchema: StructType = query.schema override def children: Seq[LogicalPlan] = Seq(query) - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = childrenResolved && { // the table schema is created from the query schema, so the only resolution needed is to check // that the columns referenced by the table's partitioning exist in the query schema val references = partitioning.flatMap(_.references).toSet @@ -507,15 +507,22 @@ case class ReplaceTableAsSelect( case class AppendData( table: NamedRelation, query: LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand object AppendData { - def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, isByName = true) + def byName( + table: NamedRelation, + df: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): AppendData = { + new AppendData(table, df, writeOptions, isByName = true) } - def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, isByName = false) + def byPosition( + table: NamedRelation, + query: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): AppendData = { + new AppendData(table, query, writeOptions, isByName = false) } } @@ -526,19 +533,26 @@ case class OverwriteByExpression( table: NamedRelation, deleteExpr: Expression, query: LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand { 
override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved } object OverwriteByExpression { def byName( - table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { - OverwriteByExpression(table, deleteExpr, df, isByName = true) + table: NamedRelation, + df: LogicalPlan, + deleteExpr: Expression, + writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true) } def byPosition( - table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { - OverwriteByExpression(table, deleteExpr, query, isByName = false) + table: NamedRelation, + query: LogicalPlan, + deleteExpr: Expression, + writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false) } } @@ -548,15 +562,22 @@ object OverwriteByExpression { case class OverwritePartitionsDynamic( table: NamedRelation, query: LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand object OverwritePartitionsDynamic { - def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { - OverwritePartitionsDynamic(table, df, isByName = true) + def byName( + table: NamedRelation, + df: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, writeOptions, isByName = true) } - def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { - OverwritePartitionsDynamic(table, query, isByName = false) + def byPosition( + table: NamedRelation, + query: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, writeOptions, isByName = false) } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 2d59c42ee8684..ab33e8e5ceaf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.sources.v2.{SupportsDelete, SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Implicits { implicit class TableHelper(table: Table) { @@ -53,4 +56,10 @@ object DataSourceV2Implicits { def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) } + + implicit class OptionsHelper(options: Map[String, String]) { + def asOptions: CaseInsensitiveStringMap = { + new CaseInsensitiveStringMap(options.asJava) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 0dea1e3a68dc8..2dc4f8b680f6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -41,8 +41,11 @@ class InMemoryTable( override val properties: util.Map[String, String]) extends Table with SupportsRead with SupportsWrite with SupportsDelete { + private val allowUnsupportedTransforms = + properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean + partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform]) { + if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) { 
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d0a1d41c70dcb..13d38d4ae1e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -271,13 +271,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { modeForDSV2 match { case SaveMode.Append => runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) + AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) } case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => // truncate the table runCommand(df.sparkSession, "save") { - OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), extraOptions.toMap) } case other => @@ -383,7 +384,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val command = modeForDSV2 match { case SaveMode.Append => - AppendData.byPosition(table, df.logicalPlan) + AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap) case SaveMode.Overwrite => val conf = df.sparkSession.sessionState.conf @@ -391,9 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC if (dynamicPartitionOverwrite) { - OverwritePartitionsDynamic.byPosition(table, df.logicalPlan) + OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap) } else { - OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true)) + OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap) } case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala new file mode 100644 index 0000000000000..7986651ca4010 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.http.annotation.Experimental + +import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.IntegerType + +/** + * Interface used to write a [[Dataset]] to external storage using the v2 API. 
+ * + * @since 3.0.0 + */ +@Experimental +final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) + extends CreateTableWriter[T] { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ + import df.sparkSession.sessionState.analyzer.CatalogObjectIdentifier + + private val df: DataFrame = ds.toDF() + + private val sparkSession = ds.sparkSession + + private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) + + private val (catalog, identifier) = { + val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName + val catalog = maybeCatalog.orElse(sparkSession.sessionState.analyzer.sessionCatalog) + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default v2 catalog is set")) + .asTableCatalog + + (catalog, identifier) + } + + private val logicalPlan = df.queryExecution.logical + + private var provider: Option[String] = None + + private val options = new mutable.HashMap[String, String]() + + private val properties = new mutable.HashMap[String, String]() + + private var partitioning: Option[Seq[Transform]] = None + + override def using(provider: String): CreateTableWriter[T] = { + this.provider = Some(provider) + this + } + + override def option(key: String, value: String): DataFrameWriterV2[T] = { + this.options.put(key, value) + this + } + + override def option(key: String, value: Boolean): DataFrameWriterV2[T] = + option(key, value.toString) + + override def option(key: String, value: Long): DataFrameWriterV2[T] = option(key, value.toString) + + override def option(key: String, value: Double): DataFrameWriterV2[T] = + option(key, value.toString) + + override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { + options.foreach { + case (key, value) => + this.options.put(key, value) + } + this + } + + override def options(options: java.util.Map[String, 
String]): DataFrameWriterV2[T] = { + this.options(options.asScala) + this + } + + override def tableProperty(property: String, value: String): DataFrameWriterV2[T] = { + this.properties.put(property, value) + this + } + + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { + val asTransforms = (column +: columns).map(_.expr).map { + case Years(attr: Attribute) => + LogicalExpressions.years(attr.name) + case Months(attr: Attribute) => + LogicalExpressions.months(attr.name) + case Days(attr: Attribute) => + LogicalExpressions.days(attr.name) + case Hours(attr: Attribute) => + LogicalExpressions.hours(attr.name) + case Bucket(Literal(numBuckets: Int, IntegerType), attr: Attribute) => + LogicalExpressions.bucket(numBuckets, attr.name) + case attr: Attribute => + LogicalExpressions.identity(attr.name) + case expr => + throw new AnalysisException(s"Invalid partition transformation: ${expr.sql}") + } + + this.partitioning = Some(asTransforms) + this + } + + override def create(): Unit = { + // create and replace could alternatively create ParsedPlan statements, like + // `CreateTableFromDataFrameStatement(UnresolvedRelation(tableName), ...)`, to keep the catalog + // resolution logic in the analyzer. + runCommand("create") { + CreateTableAsSelect( + catalog, + identifier, + partitioning.getOrElse(Seq.empty), + logicalPlan, + properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap, + writeOptions = options.toMap, + ignoreIfExists = false) + } + } + + override def replace(): Unit = { + internalReplace(orCreate = false) + } + + override def createOrReplace(): Unit = { + internalReplace(orCreate = true) + } + + + /** + * Append the contents of the data frame to the output table. + * + * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The + * data frame will be validated to ensure it is compatible with the existing table. 
+ * + * + * @throws NoSuchTableException If the table does not exist. + */ + def append(): Unit = { + val append = loadTable(catalog, identifier) match { + case Some(t) => + AppendData.byName(DataSourceV2Relation.create(t), logicalPlan, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("append")(append) + } + + /** + * Overwrite rows matching the given filter condition with the contents of the data frame in + * the output table. + * + * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The + * data frame will be validated to ensure it is compatible with the existing table. + * + * @throws NoSuchTableException If the table does not exist. + */ + def overwrite(condition: Column): Unit = { + val overwrite = loadTable(catalog, identifier) match { + case Some(t) => + OverwriteByExpression.byName( + DataSourceV2Relation.create(t), logicalPlan, condition.expr, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("overwrite")(overwrite) + } + + /** + * Overwrite all partitions for which the data frame contains at least one row with the contents + * of the data frame in the output table. + * + * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces + * partitions dynamically depending on the contents of the data frame. + * + * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The + * data frame will be validated to ensure it is compatible with the existing table. + * + * @throws NoSuchTableException If the table does not exist. 
+ */ + def overwritePartitions(): Unit = { + val dynamicOverwrite = loadTable(catalog, identifier) match { + case Some(t) => + OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(t), logicalPlan, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("overwritePartitions")(dynamicOverwrite) + } + + /** + * Wrap an action to track the QueryExecution and time cost, then report to the user-registered + * callback functions. + * + * Visible for testing. + */ + private[sql] def runCommand(name: String)(command: LogicalPlan): Unit = { + val qe = sparkSession.sessionState.executePlan(command) + // call `QueryExecution.toRDD` to trigger the execution of commands. + SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd) + } + + private def internalReplace(orCreate: Boolean): Unit = { + runCommand("replace") { + ReplaceTableAsSelect( + catalog, + identifier, + partitioning.getOrElse(Seq.empty), + logicalPlan, + properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap, + writeOptions = options.toMap, + orCreate = orCreate) + } + } +} + +/** + * Configuration methods common to create/replace operations and insert/overwrite operations. + * @tparam R builder type to return + */ +trait WriteConfigMethods[R] { + /** + * Add a write option. + * + * @since 3.0.0 + */ + def option(key: String, value: String): R + + /** + * Add a boolean output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Boolean): R + + /** + * Add a long output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Long): R + + /** + * Add a double output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Double): R + + /** + * Add write options from a Scala Map. + * + * @since 3.0.0 + */ + def options(options: scala.collection.Map[String, String]): R + + /** + * Add write options from a Java Map. 
+ * + * @since 3.0.0 + */ + def options(options: java.util.Map[String, String]): R +} + +/** + * Trait to restrict calls to create and replace operations. + */ +trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { + /** + * Create a new table from the contents of the data frame. + * + * The new table's schema, partition layout, properties, and other configuration will be + * based on the configuration set on this writer. + * + * If the output table exists, this operation will fail with [[TableAlreadyExistsException]]. + * + * @throws TableAlreadyExistsException If the table already exists. + */ + def create(): Unit + + /** + * Replace an existing table with the contents of the data frame. + * + * The existing table's schema, partition layout, properties, and other configuration will be + * replaced with the contents of the data frame and the configuration set on this writer. + * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]]. + * + * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException + * If the table does not exist. + */ + def replace(): Unit + + /** + * Create a new table or replace an existing table with the contents of the data frame. + * + * The output table's schema, partition layout, properties, and other configuration will be based + * on the contents of the data frame and the configuration set on this writer. If the table + * exists, its configuration and data will be replaced. + * + * This operation does not fail with [[TableAlreadyExistsException]] if the output table exists. + */ + def createOrReplace(): Unit + + /** + * Partition the output table created by [[create]], [[createOrReplace]], or [[replace]] using + * the given columns or transforms. + * + * When specified, the table data will be stored by these values for efficient reads. 
+ * + * For example, when a table is partitioned by day, it may be stored in a directory layout like: + * `table/day=2019-06-01/`, `table/day=2019-06-02/`, and so on. + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. + * + * @since 3.0.0 + */ + def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] + + /** + * Specifies a provider for the underlying output data source. Spark's default catalog supports + * "parquet", "json", etc. + * + * @since 3.0.0 + */ + def using(provider: String): CreateTableWriter[T] + + /** + * Use the "csv" provider. + * + * This is equivalent to: + * {{{ + * using("csv") + * }}} + */ + def asCsv: CreateTableWriter[T] = using("csv") + + /** + * Use the "text" provider. + * + * This is equivalent to: + * {{{ + * using("text") + * }}} + */ + def asText: CreateTableWriter[T] = using("text") + + /** + * Use the "json" provider. + * + * This is equivalent to: + * {{{ + * using("json") + * }}} + */ + def asJson: CreateTableWriter[T] = using("json") + + /** + * Use the "parquet" provider. + * + * This is equivalent to: + * {{{ + * using("parquet") + * }}} + */ + def asParquet: CreateTableWriter[T] = using("parquet") + + /** + * Use the "orc" provider. + * + * This is equivalent to: + * {{{ + * using("orc") + * }}} + */ + def asOrc: CreateTableWriter[T] = using("orc") + + /** + * Add a table property. 
+ */ + def tableProperty(property: String, value: String): CreateTableWriter[T] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7c25397e32beb..8619849f399a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3190,6 +3190,34 @@ class Dataset[T] private[sql]( new DataFrameWriter[T](this) } + /** + * Create a write configuration builder for v2 sources. + * + * This builder is used to configure and execute write operations. For example, to append to an + * existing table, run: + * + * {{{ + * df.writeTo("catalog.db.table").append() + * }}} + * + * This can also be used to create or replace existing tables: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() + * }}} + * + * @group basic + * @since 3.0.0 + */ + def writeTo(table: String): DataFrameWriterV2[T] = { + // TODO: streaming could be adapted to use this interface + if (isStreaming) { + logicalPlan.failAnalysis( + "'write' can not be called on streaming Dataset/DataFrame") + } + new DataFrameWriterV2[T](table, this) + } + /** * Interface for saving the content of the streaming Dataset out into external storage. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index a934c095eee11..b5a573c170a2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -34,7 +32,6 @@ import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.TableCapability import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Strategy extends Strategy with PredicateHelper { @@ -212,15 +209,15 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { orCreate = orCreate) :: Nil } - case AppendData(r: DataSourceV2Relation, query, _) => + case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, r.options, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil case v2 => - AppendDataExec(v2, r.options, planLater(query)) :: Nil + AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil } - case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
val filters = splitConjunctivePredicates(deleteExpr).map { filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( @@ -228,13 +225,14 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil + OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil } - case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => - OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => + OverwritePartitionsDynamicExec( + r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(r: DataSourceV2Relation, condition) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala index 5648d5439ba5e..5a093ba5d5d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala @@ -29,14 +29,14 @@ object V2WriteSupportCheck extends (LogicalPlan => Unit) { def failAnalysis(msg: String): Unit = throw new AnalysisException(msg) override def apply(plan: LogicalPlan): Unit = plan foreach { - case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) => + case AppendData(rel: DataSourceV2Relation, _, _, _) if !rel.table.supports(BATCH_WRITE) => failAnalysis(s"Table does not support append in batch mode: ${rel.table}") - case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _) + case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _, _) if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) => failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}") - case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) => + case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _, _) => expr match { case Literal(true, BooleanType) => if (!rel.table.supports(BATCH_WRITE) || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6b8127bab1cb4..4ab0b5a5793d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3942,6 +3942,61 @@ object functions { */ def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava) + /** + * A transform for timestamps and dates to partition data into years. 
+ * + * @group partition_transforms + * @since 3.0.0 + */ + def years(e: Column): Column = withExpr { Years(e.expr) } + + /** + * A transform for timestamps and dates to partition data into months. + * + * @group partition_transforms + * @since 3.0.0 + */ + def months(e: Column): Column = withExpr { Months(e.expr) } + + /** + * A transform for timestamps and dates to partition data into days. + * + * @group partition_transforms + * @since 3.0.0 + */ + def days(e: Column): Column = withExpr { Days(e.expr) } + + /** + * A transform for timestamps to partition data into hours. + * + * @group partition_transforms + * @since 3.0.0 + */ + def hours(e: Column): Column = withExpr { Hours(e.expr) } + + /** + * A transform for any type that partitions by a hash of the input column. + * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Column, e: Column): Column = withExpr { + numBuckets.expr match { + case lit @ Literal(_, IntegerType) => + Bucket(lit, e.expr) + case _ => + throw new AnalysisException(s"Invalid number of buckets: $numBuckets") + } + } + + /** + * A transform for any type that partitions by a hash of the input column. + * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Int, e: Column): Column = withExpr { Bucket(Literal(numBuckets), e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala new file mode 100644 index 0000000000000..2bf0451e99246 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} + +class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with BeforeAndAfter { + import org.apache.spark.sql.functions._ + import testImplicits._ + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") + val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data") + df2.createOrReplaceTempView("source2") + } + + after { + spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables() + } + + test("Append: basic append") { + spark.sql("CREATE TABLE testcat.table_name (id 
bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Append: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append() + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("Append: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").append() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("Overwrite: overwrite by expression: true") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwrite(lit(true)) + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Overwrite: overwrite by expression: id = 3") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), 
Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwrite($"id" === 3) + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Overwrite: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d") + .writeTo("testcat.table_name").overwrite(lit(true)) + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("Overwrite: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").overwrite(lit(true)) + } + + assert(exc.getMessage.contains("table_name")) + } + + test("OverwritePartitions: overwrite conflicting partitions") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").withColumn("id", $"id" - 2) + .writeTo("testcat.table_name").overwritePartitions() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "d"), Row(3L, "e"), Row(4L, "f"))) + } + + test("OverwritePartitions: overwrite all rows if not partitioned") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), 
Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwritePartitions() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("OverwritePartitions: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d") + .writeTo("testcat.table_name").overwritePartitions() + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("OverwritePartitions: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").overwritePartitions() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("Create: basic behavior") { + spark.table("source").writeTo("testcat.table_name").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties.isEmpty) + } + + test("Create: with using") { + spark.table("source").writeTo("testcat.table_name").using("foo").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = 
spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === Map("provider" -> "foo").asJava) + } + + test("Create: with property") { + spark.table("source").writeTo("testcat.table_name").tableProperty("prop", "value").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === Map("prop" -> "value").asJava) + } + + test("Create: asText") { + spark.table("source").writeTo("testcat.table_name").asText.create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.properties === Map("provider" -> "text").asJava) + } + + test("Create: asCsv") { + spark.table("source").writeTo("testcat.table_name").asCsv.create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.properties === Map("provider" -> "csv").asJava) + } + + test("Create: asJson") { + spark.table("source").writeTo("testcat.table_name").asJson.create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.properties 
=== Map("provider" -> "json").asJava) + } + + test("Create: asParquet") { + spark.table("source").writeTo("testcat.table_name").asParquet.create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.properties === Map("provider" -> "parquet").asJava) + } + + test("Create: asOrc") { + spark.table("source").writeTo("testcat.table_name").asOrc.create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.properties === Map("provider" -> "orc").asJava) + } + + test("Create: identity partitioned table") { + spark.table("source").writeTo("testcat.table_name").partitionedBy($"id").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties.isEmpty) + } + + test("Create: partitioned by years(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(years($"ts")) + .create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(YearsTransform(FieldReference("ts")))) + } + + test("Create: partitioned by months(ts)") { + spark.table("source") + 
.withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(months($"ts")) + .create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(MonthsTransform(FieldReference("ts")))) + } + + test("Create: partitioned by days(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(days($"ts")) + .create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(DaysTransform(FieldReference("ts")))) + } + + test("Create: partitioned by hours(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(hours($"ts")) + .create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(HoursTransform(FieldReference("ts")))) + } + + test("Create: partitioned by bucket(4, id)") { + spark.table("source") + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(bucket(4, $"id")) + .create() + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === + Seq(BucketTransform(LiteralValue(4, 
IntegerType), Seq(FieldReference("id"))))) + } + + test("Create: fail if table already exists") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + val exc = intercept[TableAlreadyExistsException] { + spark.table("source").writeTo("testcat.table_name").create() + } + + assert(exc.getMessage.contains("table_name")) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // table should not have been changed + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + } + + test("Replace: basic behavior") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").replace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + 
// validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType) + .add("even_or_odd", StringType, nullable = false)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } + + test("Replace: partitioned table") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").partitionedBy($"id").replace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType) + .add("even_or_odd", StringType, nullable = false)) + assert(replaced.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(replaced.properties.isEmpty) + } + + test("Replace: fail if table does not exist") { + val exc = intercept[CannotReplaceMissingTableException] { + 
spark.table("source").writeTo("testcat.table_name").replace() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("CreateOrReplace: table does not exist") { + spark.table("source2").writeTo("testcat.table_name").createOrReplace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + + val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } + + test("CreateOrReplace: table exists") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").createOrReplace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + .loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + 
assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType) + .add("even_or_odd", StringType, nullable = false)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } +} From 5dbc850f4f01808460aad3d42cf590bb3c416a37 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Sun, 4 Aug 2019 14:39:04 -0700 Subject: [PATCH 02/10] Fix bad import. --- .../main/scala/org/apache/spark/sql/DataFrameWriterV2.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 7986651ca4010..da3a8d3c5a78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.http.annotation.Experimental - +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform} import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} From 002339b47f19de1427f2d3da67f4c1b84b2f64b2 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 5 Aug 2019 09:18:58 -0700 Subject: [PATCH 03/10] Fix DataFrameWriterV2 javadoc. 
--- .../apache/spark/sql/DataFrameWriterV2.scala | 67 +++++++++---------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index da3a8d3c5a78b..2e9607822bbdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform} -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.execution.SQLExecution @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.IntegerType /** - * Interface used to write a [[Dataset]] to external storage using the v2 API. + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. 
* * @since 3.0.0 */ @@ -78,14 +78,6 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) this } - override def option(key: String, value: Boolean): DataFrameWriterV2[T] = - option(key, value.toString) - - override def option(key: String, value: Long): DataFrameWriterV2[T] = option(key, value.toString) - - override def option(key: String, value: Double): DataFrameWriterV2[T] = - option(key, value.toString) - override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { options.foreach { case (key, value) => @@ -155,12 +147,13 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) /** * Append the contents of the data frame to the output table. * - * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The - * data frame will be validated to ensure it is compatible with the existing table. - * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be + * validated to ensure it is compatible with the existing table. * - * @throws NoSuchTableException If the table does not exist. + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ + @throws(classOf[NoSuchTableException]) def append(): Unit = { val append = loadTable(catalog, identifier) match { case Some(t) => @@ -176,11 +169,13 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) * Overwrite rows matching the given filter condition with the contents of the data frame in * the output table. * - * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The - * data frame will be validated to ensure it is compatible with the existing table. + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. 
+ * The data frame will be validated to ensure it is compatible with the existing table. * - * @throws NoSuchTableException If the table does not exist. + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ + @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = loadTable(catalog, identifier) match { case Some(t) => @@ -200,11 +195,13 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces * partitions dynamically depending on the contents of the data frame. * - * If the output table does not exist, this operation will fail with [[NoSuchTableException]]. The - * data frame will be validated to ensure it is compatible with the existing table. + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be + * validated to ensure it is compatible with the existing table. * - * @throws NoSuchTableException If the table does not exist. + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ + @throws(classOf[NoSuchTableException]) def overwritePartitions(): Unit = { val dynamicOverwrite = loadTable(catalog, identifier) match { case Some(t) => @@ -220,10 +217,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) /** * Wrap an action to track the QueryExecution and time cost, then report to the user-registered * callback functions. - * - * Visible for testing. */ - private[sql] def runCommand(name: String)(command: LogicalPlan): Unit = { + private def runCommand(name: String)(command: LogicalPlan): Unit = { val qe = sparkSession.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. 
SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd) @@ -260,21 +255,21 @@ trait WriteConfigMethods[R] { * * @since 3.0.0 */ - def option(key: String, value: Boolean): R + def option(key: String, value: Boolean): R = option(key, value.toString) /** * Add a long output option. * * @since 3.0.0 */ - def option(key: String, value: Long): R + def option(key: String, value: Long): R = option(key, value.toString) /** * Add a double output option. * * @since 3.0.0 */ - def option(key: String, value: Double): R + def option(key: String, value: Double): R = option(key, value.toString) /** * Add write options from a Scala Map. @@ -301,10 +296,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * The new table's schema, partition layout, properties, and other configuration will be * based on the configuration set on this writer. * - * If the output table exists, this operation will fail with [[TableAlreadyExistsException]]. + * If the output table exists, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]]. * - * @throws TableAlreadyExistsException If the table already exists. + * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException + * If the table already exists */ + @throws(classOf[TableAlreadyExistsException]) def create(): Unit /** @@ -313,10 +311,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * The existing table's schema, partition layout, properties, and other configuration will be * replaced with the contents of the data frame and the configuration set on this writer. * - * If the output table exists, this operation will fail with [[TableAlreadyExistsException]]. + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]]. * - * @throws TableAlreadyExistsException If the table already exists. 
+ * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException + * If the table does not exist */ + @throws(classOf[CannotReplaceMissingTableException]) def replace(): Unit /** @@ -325,15 +326,11 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { * The output table's schema, partition layout, properties, and other configuration will be based * on the contents of the data frame and the configuration set on this writer. If the table * exists, its configuration and data will be replaced. - * - * If the output table exists, this operation will fail with [[TableAlreadyExistsException]]. - * - * @throws TableAlreadyExistsException If the table already exists. */ def createOrReplace(): Unit /** - * Partition the output table created by [[create]], [[createOrReplace]], or [[replace]] using + * Partition the output table created by `create`, `createOrReplace`, or `replace` using * the given columns or transforms. * * When specified, the table data will be stored by these values for efficient reads. From 6a5250932acfb4e247aaee6ab1b6940a7cc3b377 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 7 Aug 2019 15:40:48 -0700 Subject: [PATCH 04/10] Fix comments from reviewers. 
--- .../spark/sql/catalyst/expressions/PartitionTransforms.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala index ab18c4a7cb1b3..e48fd8adaef09 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * These expressions are used to pass transformations from the DataFrame API: * * {{{ - * df.writeTo("catalog.db.table").partitionBy($"category", days($"timestamp")).create() + * df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create() * }}} */ abstract class PartitionTransformExpression extends Expression with Unevaluable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8619849f399a0..23360df04594b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3213,7 +3213,7 @@ class Dataset[T] private[sql]( // TODO: streaming could be adapted to use this interface if (isStreaming) { logicalPlan.failAnalysis( - "'write' can not be called on streaming Dataset/DataFrame") + "'writeTo' can not be called on streaming Dataset/DataFrame") } new DataFrameWriterV2[T](table, this) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4ab0b5a5793d5..92f4a8fcfeb23 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -69,6 +69,7 @@ import org.apache.spark.util.Utils * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname collection_funcs Collection functions + * @groupname partition_transforms Partition transform functions * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ From f545692f9c601a2b543151125cb248085da21bd3 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 20 Aug 2019 16:27:00 -0700 Subject: [PATCH 05/10] Update tests for recent changes in master. --- .../sources/v2/DataFrameWriterV2Suite.scala | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala index 2bf0451e99246..ec51cb240b9a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -22,16 +22,20 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier} import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} -class DataFrameWriterV2Suite 
extends QueryTest with SharedSQLContext with BeforeAndAfter { +class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { import org.apache.spark.sql.functions._ import testImplicits._ + private def catalog(name: String): CatalogPlugin = { + spark.sessionState.catalogManager.catalog(name) + } + before { spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) @@ -42,7 +46,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before } after { - spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables() + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() } test("Append: basic append") { @@ -223,7 +228,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -241,7 +246,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -259,7 +264,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === 
"testcat.table_name") @@ -273,7 +278,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before test("Create: asText") { spark.table("source").writeTo("testcat.table_name").asText.create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -283,7 +288,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before test("Create: asCsv") { spark.table("source").writeTo("testcat.table_name").asCsv.create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -293,7 +298,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before test("Create: asJson") { spark.table("source").writeTo("testcat.table_name").asJson.create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -303,7 +308,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before test("Create: asParquet") { spark.table("source").writeTo("testcat.table_name").asParquet.create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -313,7 +318,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before test("Create: asOrc") { spark.table("source").writeTo("testcat.table_name").asOrc.create() - val table = 
spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -327,7 +332,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -346,7 +351,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before .partitionedBy(years($"ts")) .create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -361,7 +366,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before .partitionedBy(months($"ts")) .create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -376,7 +381,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before .partitionedBy(days($"ts")) .create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -391,7 +396,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before .partitionedBy(hours($"ts")) .create() - val table = 
spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -405,7 +410,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before .partitionedBy(bucket(4, $"id")) .create() - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") @@ -423,7 +428,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before assert(exc.getMessage.contains("table_name")) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // table should not have been changed @@ -442,7 +447,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the initial table @@ -459,7 +464,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table @@ -480,7 +485,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), 
Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the initial table @@ -497,7 +502,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table @@ -525,7 +530,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) - val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table @@ -546,7 +551,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the initial table @@ -563,7 +568,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSQLContext with Before spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] + val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] .loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table 
From 6c0a98be7f7bd099539dbae9dd301627ace8e345 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 21 Aug 2019 10:02:48 -0700 Subject: [PATCH 06/10] Improve error message in bucket function. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 92f4a8fcfeb23..3e3facea6af04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3986,7 +3986,7 @@ object functions { case lit @ Literal(_, IntegerType) => Bucket(lit, e.expr) case _ => - throw new AnalysisException(s"Invalid number of buckets: $numBuckets") + throw new AnalysisException(s"Invalid number of buckets: bucket($numBuckets, $e)") } } From 1bc6954e7d6176ecdc56fb459e2e5bce01e52f5e Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 22 Aug 2019 10:46:52 -0700 Subject: [PATCH 07/10] Move partitioning functions into a partitioning object. --- .../org/apache/spark/sql/functions.scala | 110 ++++++++++-------- .../sources/v2/DataFrameWriterV2Suite.scala | 10 +- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3e3facea6af04..0ece755023ee2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3943,61 +3943,69 @@ object functions { */ def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava) - /** - * A transform for timestamps and dates to partition data into years. - * - * @group partition_transforms - * @since 3.0.0 - */ - def years(e: Column): Column = withExpr { Years(e.expr) } - - /** - * A transform for timestamps and dates to partition data into months. 
- * - * @group partition_transforms - * @since 3.0.0 - */ - def months(e: Column): Column = withExpr { Months(e.expr) } - - /** - * A transform for timestamps and dates to partition data into days. - * - * @group partition_transforms - * @since 3.0.0 - */ - def days(e: Column): Column = withExpr { Days(e.expr) } - - /** - * A transform for timestamps to partition data into hours. - * - * @group partition_transforms - * @since 3.0.0 - */ - def hours(e: Column): Column = withExpr { Hours(e.expr) } + // turn off style check that object names must start with a capital letter + // scalastyle:off + object partitioning { + // scalastyle:on + + /** + * A transform for timestamps and dates to partition data into years. + * + * @group partition_transforms + * @since 3.0.0 + */ + def years(e: Column): Column = withExpr { Years(e.expr) } + + /** + * A transform for timestamps and dates to partition data into months. + * + * @group partition_transforms + * @since 3.0.0 + */ + def months(e: Column): Column = withExpr { Months(e.expr) } + + /** + * A transform for timestamps and dates to partition data into days. + * + * @group partition_transforms + * @since 3.0.0 + */ + def days(e: Column): Column = withExpr { Days(e.expr) } + + /** + * A transform for timestamps to partition data into hours. + * + * @group partition_transforms + * @since 3.0.0 + */ + def hours(e: Column): Column = withExpr { Hours(e.expr) } + + /** + * A transform for any type that partitions by a hash of the input column. + * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Column, e: Column): Column = withExpr { + numBuckets.expr match { + case lit @ Literal(_, IntegerType) => + Bucket(lit, e.expr) + case _ => + throw new AnalysisException(s"Invalid number of buckets: bucket($numBuckets, $e)") + } + } - /** - * A transform for any type that partitions by a hash of the input column. 
- * - * @group partition_transforms - * @since 3.0.0 - */ - def bucket(numBuckets: Column, e: Column): Column = withExpr { - numBuckets.expr match { - case lit @ Literal(_, IntegerType) => - Bucket(lit, e.expr) - case _ => - throw new AnalysisException(s"Invalid number of buckets: bucket($numBuckets, $e)") + /** + * A transform for any type that partitions by a hash of the input column. + * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Int, e: Column): Column = withExpr { + Bucket(Literal(numBuckets), e.expr) } } - /** - * A transform for any type that partitions by a hash of the input column. - * - * @group partition_transforms - * @since 3.0.0 - */ - def bucket(numBuckets: Int, e: Column): Column = withExpr { Bucket(Literal(numBuckets), e.expr) } - // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala index ec51cb240b9a5..07383960a7a9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -348,7 +348,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") .tableProperty("allow-unsupported-transforms", "true") - .partitionedBy(years($"ts")) + .partitionedBy(partitioning.years($"ts")) .create() val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] @@ -363,7 +363,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") .tableProperty("allow-unsupported-transforms", "true") - .partitionedBy(months($"ts")) + 
.partitionedBy(partitioning.months($"ts")) .create() val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] @@ -378,7 +378,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") .tableProperty("allow-unsupported-transforms", "true") - .partitionedBy(days($"ts")) + .partitionedBy(partitioning.days($"ts")) .create() val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] @@ -393,7 +393,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") .tableProperty("allow-unsupported-transforms", "true") - .partitionedBy(hours($"ts")) + .partitionedBy(partitioning.hours($"ts")) .create() val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] @@ -407,7 +407,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .writeTo("testcat.table_name") .tableProperty("allow-unsupported-transforms", "true") - .partitionedBy(bucket(4, $"id")) + .partitionedBy(partitioning.bucket(4, $"id")) .create() val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] From e424c2c36ef2c38a689f90a12fe13ca4ff9a6098 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 30 Aug 2019 14:25:20 -0700 Subject: [PATCH 08/10] Remove asFormat methods. 
--- .../apache/spark/sql/DataFrameWriterV2.scala | 50 ------------------- .../sources/v2/DataFrameWriterV2Suite.scala | 50 ------------------- 2 files changed, 100 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 2e9607822bbdd..57b212e6b9fe3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -358,56 +358,6 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { */ def using(provider: String): CreateTableWriter[T] - /** - * Use the "csv" provider. - * - * This is equivalent to: - * {{{ - * using("csv") - * }}} - */ - def asCsv: CreateTableWriter[T] = using("csv") - - /** - * Use the "text" provider. - * - * This is equivalent to: - * {{{ - * using("text") - * }}} - */ - def asText: CreateTableWriter[T] = using("text") - - /** - * Use the "json" provider. - * - * This is equivalent to: - * {{{ - * using("json") - * }}} - */ - def asJson: CreateTableWriter[T] = using("json") - - /** - * Use the "parquet" provider. - * - * This is equivalent to: - * {{{ - * using("parquet") - * }}} - */ - def asParquet: CreateTableWriter[T] = using("parquet") - - /** - * Use the "orc" provider. - * - * This is equivalent to: - * {{{ - * using("orc") - * }}} - */ - def asOrc: CreateTableWriter[T] = using("orc") - /** * Add a table property. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala index 07383960a7a9e..9d497bf2df7a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -275,56 +275,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(table.properties === Map("prop" -> "value").asJava) } - test("Create: asText") { - spark.table("source").writeTo("testcat.table_name").asText.create() - - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) - - assert(table.name === "testcat.table_name") - assert(table.properties === Map("provider" -> "text").asJava) - } - - test("Create: asCsv") { - spark.table("source").writeTo("testcat.table_name").asCsv.create() - - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) - - assert(table.name === "testcat.table_name") - assert(table.properties === Map("provider" -> "csv").asJava) - } - - test("Create: asJson") { - spark.table("source").writeTo("testcat.table_name").asJson.create() - - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) - - assert(table.name === "testcat.table_name") - assert(table.properties === Map("provider" -> "json").asJava) - } - - test("Create: asParquet") { - spark.table("source").writeTo("testcat.table_name").asParquet.create() - - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) - - assert(table.name === "testcat.table_name") - assert(table.properties === Map("provider" -> "parquet").asJava) - } - - test("Create: asOrc") { - 
spark.table("source").writeTo("testcat.table_name").asOrc.create() - - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) - - assert(table.name === "testcat.table_name") - assert(table.properties === Map("provider" -> "orc").asJava) - } - test("Create: identity partitioned table") { spark.table("source").writeTo("testcat.table_name").partitionedBy($"id").create() From 57e6c5be49add3cc76f402a16ebe88c28fc07bed Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 30 Aug 2019 15:34:39 -0700 Subject: [PATCH 09/10] Update tests for InMemoryTableCatalog consolidation. --- .../sources/v2/DataFrameWriterV2Suite.scala | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala index 9d497bf2df7a1..3e73b95423d33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -22,22 +22,24 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier} +import org.apache.spark.sql.catalog.v2.{ Identifier, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.InMemoryTableCatalog import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} class 
DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ import org.apache.spark.sql.functions._ import testImplicits._ - private def catalog(name: String): CatalogPlugin = { - spark.sessionState.catalogManager.catalog(name) + private def catalog(name: String): TableCatalog = { + spark.sessionState.catalogManager.catalog(name).asTableCatalog } before { - spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") @@ -228,8 +230,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.schema === new StructType() @@ -246,8 +247,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.schema === new StructType() @@ -264,8 +264,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + 
val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.schema === new StructType() @@ -282,8 +281,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.schema === new StructType() @@ -301,8 +299,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .partitionedBy(partitioning.years($"ts")) .create() - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.partitioning === Seq(YearsTransform(FieldReference("ts")))) @@ -316,8 +313,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .partitionedBy(partitioning.months($"ts")) .create() - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.partitioning === Seq(MonthsTransform(FieldReference("ts")))) @@ -331,8 +327,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .partitionedBy(partitioning.days($"ts")) .create() - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") 
assert(table.partitioning === Seq(DaysTransform(FieldReference("ts")))) @@ -346,8 +341,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .partitionedBy(partitioning.hours($"ts")) .create() - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.partitioning === Seq(HoursTransform(FieldReference("ts")))) @@ -360,8 +354,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .partitionedBy(partitioning.bucket(4, $"id")) .create() - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") assert(table.partitioning === @@ -378,8 +371,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(exc.getMessage.contains("table_name")) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // table should not have been changed assert(table.name === "testcat.table_name") @@ -397,8 +389,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the initial table assert(table.name === "testcat.table_name") @@ -414,8 +405,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo 
spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table assert(replaced.name === "testcat.table_name") @@ -435,8 +425,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the initial table assert(table.name === "testcat.table_name") @@ -452,8 +441,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table assert(replaced.name === "testcat.table_name") @@ -480,8 +468,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) - val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table assert(replaced.name === "testcat.table_name") @@ -501,8 +488,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(1L, "a"), Row(2L, "b"), 
Row(3L, "c"))) - val table = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the initial table assert(table.name === "testcat.table_name") @@ -518,8 +504,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("testcat.table_name"), Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) - val replaced = catalog("testcat").asInstanceOf[TestInMemoryTableCatalog] - .loadTable(Identifier.of(Array(), "table_name")) + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) // validate the replacement table assert(replaced.name === "testcat.table_name") From 9864d42501feff1acd01e11aedd2a7dc84a88bd5 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 30 Aug 2019 19:18:32 -0700 Subject: [PATCH 10/10] Update test cases for CTAS with nullable schemas. --- .../sources/v2/DataFrameWriterV2Suite.scala | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala index 3e73b95423d33..810a192f331d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -233,9 +233,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") - assert(table.schema === new StructType() - .add("id", LongType, nullable = false) - .add("data", StringType)) + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) assert(table.partitioning.isEmpty) 
assert(table.properties.isEmpty) } @@ -250,9 +248,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") - assert(table.schema === new StructType() - .add("id", LongType, nullable = false) - .add("data", StringType)) + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) assert(table.partitioning.isEmpty) assert(table.properties === Map("provider" -> "foo").asJava) } @@ -267,9 +263,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") - assert(table.schema === new StructType() - .add("id", LongType, nullable = false) - .add("data", StringType)) + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) assert(table.partitioning.isEmpty) assert(table.properties === Map("prop" -> "value").asJava) } @@ -284,9 +278,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) assert(table.name === "testcat.table_name") - assert(table.schema === new StructType() - .add("id", LongType, nullable = false) - .add("data", StringType)) + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) assert(table.properties.isEmpty) } @@ -410,9 +402,9 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo // validate the replacement table assert(replaced.name === "testcat.table_name") assert(replaced.schema === new StructType() - .add("id", LongType, nullable = false) + .add("id", LongType) .add("data", StringType) - .add("even_or_odd", StringType, nullable = false)) + 
.add("even_or_odd", StringType)) assert(replaced.partitioning.isEmpty) assert(replaced.properties.isEmpty) } @@ -446,9 +438,9 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo // validate the replacement table assert(replaced.name === "testcat.table_name") assert(replaced.schema === new StructType() - .add("id", LongType, nullable = false) + .add("id", LongType) .add("data", StringType) - .add("even_or_odd", StringType, nullable = false)) + .add("even_or_odd", StringType)) assert(replaced.partitioning === Seq(IdentityTransform(FieldReference("id")))) assert(replaced.properties.isEmpty) } @@ -472,9 +464,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo // validate the replacement table assert(replaced.name === "testcat.table_name") - assert(replaced.schema === new StructType() - .add("id", LongType, nullable = false) - .add("data", StringType)) + assert(replaced.schema === new StructType().add("id", LongType).add("data", StringType)) assert(replaced.partitioning.isEmpty) assert(replaced.properties.isEmpty) } @@ -509,9 +499,9 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo // validate the replacement table assert(replaced.name === "testcat.table_name") assert(replaced.schema === new StructType() - .add("id", LongType, nullable = false) + .add("id", LongType) .add("data", StringType) - .add("even_or_odd", StringType, nullable = false)) + .add("even_or_odd", StringType)) assert(replaced.partitioning.isEmpty) assert(replaced.properties.isEmpty) }