diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocolV2.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocolV2.scala new file mode 100644 index 0000000000000..0678cfde8e49a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocolV2.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +abstract class FileCommitProtocolV2 extends FileCommitProtocol { + + @deprecated("use newTaskTempFileV2", "3.1.0") + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + @deprecated("use newTaskTempFileAbsPathV2", "3.1.0") + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. 
+ * + * A full file path consists of the following parts: + * 1. the base path + * 2. the relative file path + * + * The "relativeFilePath" parameter specifies 2. The base path is left to the commit protocol + * implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to + * "relativeFilePath" if a task is going to write out multiple files to the same dir. The file + * commit protocol only guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileV2(taskContext: TaskAttemptContext, relativeFilePath: String): String + + /** + * Similar to newTaskTempFileV2(), but lets files be committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to + * "absoluteFilePath" if a task is going to write out multiple files to the same dir. The file + * commit protocol only guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPathV2(taskContext: TaskAttemptContext, absoluteFilePath: String): String +} + +object FileCommitProtocolV2 { + + final def getFilename( + taskContext: TaskAttemptContext, jobId: String, prefix: String, ext: String): String = { + // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"${prefix}part-$split%05d-$jobId$ext" + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 11ce608f52ee2..f6b0dbe766f7e 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -53,9 +53,10 @@ class HadoopMapReduceCommitProtocol( jobId: String, path: String, dynamicPartitionOverwrite: Boolean = false) - extends FileCommitProtocol with Serializable with Logging { + extends FileCommitProtocolV2 with Serializable with Logging { import FileCommitProtocol._ + import FileCommitProtocolV2._ /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ @@ -101,9 +102,40 @@ class HadoopMapReduceCommitProtocol( format.getOutputCommitter(context) } + override def newTaskTempFileV2( + taskContext: TaskAttemptContext, relativeFilePath: String): String = { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + val dir = new Path(relativeFilePath).getParent + assert(dir != null, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.toString + this.stagingDir + // For FileOutputCommitter it has its own staging path called "work path". 
+ case f: FileOutputCommitter => + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) + } + + new Path(stagingDir, relativeFilePath).toString + } + + override def newTaskTempFileAbsPathV2( + taskContext: TaskAttemptContext, absoluteFilePath: String): String = { + val filename = new Path(absoluteFilePath).getName + val absOutputPath = new Path(absoluteFilePath).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - val filename = getFilename(taskContext, ext) + val filename = getFilename(taskContext, jobId, "", ext) val stagingDir: Path = committer match { case _ if dynamicPartitionOverwrite => @@ -126,7 +158,7 @@ class HadoopMapReduceCommitProtocol( override def newTaskTempFileAbsPath( taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { - val filename = getFilename(taskContext, ext) + val filename = getFilename(taskContext, jobId, "", ext) val absOutputPath = new Path(absoluteDir, filename).toString // Include a UUID here to prevent file collisions for one task writing to different dirs. @@ -137,14 +169,6 @@ class HadoopMapReduceCommitProtocol( tmpOutputPath } - protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = { - // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. 
- val split = taskContext.getTaskAttemptID.getTaskID.getId - f"part-$split%05d-$jobId$ext" - } - override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) diff --git a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala index 2ca50878485c0..ff8741bd88f1b 100644 --- a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala +++ b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, PathOutputCommitter, PathOutputCommitterFactory} +import org.apache.spark.internal.io.FileCommitProtocolV2._ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol /** @@ -134,7 +135,7 @@ class PathOutputCommitProtocol( val parent = dir.map { d => new Path(workDir, d) }.getOrElse(workDir) - val file = new Path(parent, getFilename(taskContext, ext)) + val file = new Path(parent, getFilename(taskContext, jobId, "", ext)) logTrace(s"Creating task file $file for dir $dir and ext $ext") file.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 44636beeec7fc..fc18fe742d609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, 
UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -260,7 +261,7 @@ case class FileSourceScanExec( // exposed for testing lazy val bucketedScan: Boolean = { if (relation.sparkSession.sessionState.conf.bucketingEnabled && relation.bucketSpec.isDefined - && !disableBucketedScan) { + && !disableBucketedScan && !DDLUtils.isHiveTable(relation.options.get(DDLUtils.PROVIDER))) { val spec = relation.bucketSpec.get val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) bucketColumns.size == spec.bucketColumnNames.size diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index d550fe270c753..6e2d73ac4c83b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -837,7 +837,9 @@ case class AlterTableSetLocationCommand( object DDLUtils { + val PROVIDER = "provider" val HIVE_PROVIDER = "hive" + val HIVE_VERSION = "hive_version" def isHiveTable(table: CatalogTable): Boolean = { isHiveTable(table.provider) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index edb49d3f90ca3..11c2d25759d9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import 
org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{FileCommitProtocol, FileCommitProtocolV2} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -165,6 +165,10 @@ class DynamicPartitionDataWriter( |WriteJobDescription: $description """.stripMargin) + /** Flag saying whether or not to use [[FileCommitProtocolV2]]. */ + private val isFileCommitProtocolV2 = committer.isInstanceOf[FileCommitProtocolV2] && + description.bucketFileNamePrefix.isDefined + private var fileCounter: Int = _ private var recordsInFile: Long = _ private var currentPartionValues: Option[UnsafeRow] = None @@ -229,11 +233,39 @@ class DynamicPartitionDataWriter( val customPath = partDir.flatMap { dir => description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } + + val currentPath = + if (isFileCommitProtocolV2) { + val fileNamePrefix = (description.bucketFileNamePrefix, bucketId) match { + case (Some(prefix), Some(id)) => Some(prefix(id)) + case _ => None + } + + (committer, fileNamePrefix) match { + case (c: FileCommitProtocolV2, Some(prefix)) => + val fileName = FileCommitProtocolV2.getFilename( + taskAttemptContext, description.uuid, prefix, ext) + if (customPath.isDefined) { + val absoluteFilePath = new Path(customPath.get, fileName).toString + c.newTaskTempFileAbsPathV2(taskAttemptContext, absoluteFilePath) + } else { + val relativeFilePath = partDir match { + case Some(dir) => new Path(dir, fileName).toString + case None => fileName + } + c.newTaskTempFileV2(taskAttemptContext, relativeFilePath) + } + case c => + throw new 
IllegalArgumentException( + s"DynamicPartitionDataWriter should not take $c as the file commit protocol") + } + } else { + if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + } currentWriter = description.outputWriterFactory.newInstance( path = currentPath, @@ -286,6 +318,7 @@ class WriteJobDescription( val dataColumns: Seq[Attribute], val partitionColumns: Seq[Attribute], val bucketIdExpression: Option[Expression], + val bucketFileNamePrefix: Option[Int => String], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index a71aeb47872ce..ebc5684e50bd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String @@ -69,6 +70,17 @@ object FileFormatWriter extends Logging { } } + /** + * A function that gets bucket file name prefix given bucket id. + * The new bucket file name follows the Hive and Presto convention, so this makes sure + * Hive bucketed tables written by Spark can be read by other SQL engines like Hive and Presto. 
+ * + * Hive bucketing naming: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + * Presto bucketing naming (prestosql here): + * `io.prestosql.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + */ + def compatibleBucketFileNamePrefix(bucketId: Int): String = f"$bucketId%05d_0_" + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -113,12 +125,32 @@ object FileFormatWriter extends Logging { } val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan - val bucketIdExpression = bucketSpec.map { spec => + var bucketFileNamePrefix: Option[Int => String] = None + val bucketIdExpression = bucketSpec.flatMap { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can - // guarantee the data distribution is same between shuffle and bucketed data source, which - // enables us to only shuffle one side when join a bucketed table and a normal one. - HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + if (DDLUtils.isHiveTable(options.get(DDLUtils.PROVIDER))) { + val hiveVersion = options.getOrElse(DDLUtils.HIVE_VERSION, "") + val hiveVersion012 = Seq("0.", "1.", "2.") + if (hiveVersion012.exists(hiveVersion.startsWith)) { + bucketFileNamePrefix = Some(compatibleBucketFileNamePrefix) + // For Hive bucketed table, use `HiveHash` and bitwise-and as our bucket id expression. + // Without the extra bitwise-and operation, we can get wrong bucket id when hash value + // of columns is negative. See Hive implementation in + // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. 
+ val bucketId = HiveHash(bucketColumns) + val bucketIdAfterAnd = BitwiseAnd(bucketId, Literal(Int.MaxValue)) + Some(Pmod(bucketIdAfterAnd, Literal(spec.numBuckets))) + } else { + // TODO(SPARK-32710/32711): Write Hive 3.x ORC/Parquet bucketed table + None + } + } else { + // For Spark data source bucketed table, use `HashPartitioning.partitionIdExpression` + // as our bucket id expression, so that we can guarantee the data distribution is same + // between shuffle and bucketed data source, which enables us to only shuffle one side + // when join a bucketed table and a normal one. + Some(HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression) + } } val sortColumns = bucketSpec.toSeq.flatMap { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -140,6 +172,7 @@ object FileFormatWriter extends Logging { dataColumns = dataColumns, partitionColumns = partitionColumns, bucketIdExpression = bucketIdExpression, + bucketFileNamePrefix = bucketFileNamePrefix, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index cd62ee7814bf2..023ecf3f9c26c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -130,6 +130,7 @@ abstract class FileWriteBuilder( dataColumns = allColumns, partitionColumns = Seq.empty, bucketIdExpression = None, + bucketFileNamePrefix = None, path = pathName, customPartitionLocations = Map.empty, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index a410f32d4af7e..f4d6a98e81630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.functions._ @@ -136,29 +136,37 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") } - def tableDir: File = { - val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") + def tableDir(table: String = "bucketed_table"): File = { + val identifier = spark.sessionState.sqlParser.parseTableIdentifier(table) new File(spark.sessionState.catalog.defaultTablePath(identifier)) } + private def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression = + HashPartitioning(expressions, numBuckets).partitionIdExpression + /** * A helper method to check the bucket write functionality in low level, i.e. check the written * bucket files to see if the data are correct. User should pass in a data dir that these bucket * files are written to, and the format of data(parquet, json, etc.), and the bucketing * information. 
*/ - private def testBucketing( + protected def testBucketing( dataDir: File, source: String, numBuckets: Int, bucketCols: Seq[String], - sortCols: Seq[String] = Nil): Unit = { + sortCols: Seq[String] = Nil, + inputDF: DataFrame = df, + bucketIdExpression: (Seq[Expression], Int) => Expression = bucketIdExpression, + getBucketIdFromFileName: String => Option[Int] = BucketingUtils.getBucketId) + : Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => f.getName.startsWith(".") || f.getName.startsWith("_") ) for (bucketFile <- allBucketFiles) { - val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse { + val bucketId = getBucketIdFromFileName(bucketFile.getName).getOrElse { fail(s"Unable to find the related bucket files.") } @@ -167,7 +175,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val selectedColumns = (bucketCols ++ sortCols).distinct // We may lose the type information after write(e.g. json format doesn't keep schema // information), here we get the types from the original dataframe. 
- val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType) + val types = inputDF.select(selectedColumns.map(col): _*).schema.map(_.dataType) val columns = selectedColumns.zip(types).map { case (colName, dt) => col(colName).cast(dt) } @@ -188,7 +196,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val qe = readBack.select(bucketCols.map(col): _*).queryExecution val rows = qe.toRdd.map(_.copy()).collect() val getBucketId = UnsafeProjection.create( - HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil, + bucketIdExpression(qe.analyzed.output, numBuckets) :: Nil, qe.analyzed.output) for (row <- rows) { @@ -208,7 +216,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k")) } } } @@ -225,7 +233,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j"), Seq("k")) } } } @@ -255,7 +263,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .bucketBy(8, "i", "j") .saveAsTable("bucketed_table") - testBucketing(tableDir, source, 8, Seq("i", "j")) + testBucketing(tableDir(), source, 8, Seq("i", "j")) } } } @@ -269,7 +277,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .sortBy("k") .saveAsTable("bucketed_table") - testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k")) + testBucketing(tableDir(), source, 8, Seq("i", "j"), Seq("k")) } } } @@ -286,7 +294,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { .saveAsTable("bucketed_table") for (i <- 0 until 5) { - testBucketing(new File(tableDir, s"i=$i"), 
source, 8, Seq("j", "k")) + testBucketing(new File(tableDir(), s"i=$i"), source, 8, Seq("j", "k")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a89243c331c7b..4d6737a0b226b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.internal.SQLConf @@ -167,6 +168,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val tablePath = new Path(relation.tableMeta.location) val fileFormat = fileFormatClass.getConstructor().newInstance() + val hiveVersion012 = Seq("0.", "1.", "2.") + val hiveVersion = SQLConf.get.getConf(HiveUtils.HIVE_METASTORE_VERSION) + val bucketSpec = + if (hiveVersion012.exists(hiveVersion.startsWith)) { + relation.tableMeta.bucketSpec + } else { + // TODO(SPARK-32710/32711): Write Hive 3.x ORC/Parquet bucketed table + None + } + val optionsWithHiveInfo = options.updated(DDLUtils.PROVIDER, DDLUtils.HIVE_PROVIDER) + .updated(DDLUtils.HIVE_VERSION, SQLConf.get.getConf(HiveUtils.HIVE_METASTORE_VERSION)) + val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { @@ -211,12 +224,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // Spark SQL's data source table now 
support static and dynamic partition insert. Source // table converted from Hive table should always use dynamic. - val enableDynamicPartition = options.updated("partitionOverwriteMode", "dynamic") + val enableDynamicPartition = optionsWithHiveInfo.updated( + "partitionOverwriteMode", "dynamic") val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, dataSchema = updatedTable.dataSchema, - bucketSpec = None, + bucketSpec = bucketSpec, fileFormat = fileFormat, options = enableDynamicPartition)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) @@ -243,8 +257,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log sparkSession = sparkSession, paths = rootPath.toString :: Nil, userSpecifiedSchema = Option(updatedTable.dataSchema), - bucketSpec = None, - options = options, + bucketSpec = bucketSpec, + options = optionsWithHiveInfo, className = fileType).resolveRelation(), table = updatedTable) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala index bdbdcc2951072..931d92251b44d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteWithHiveSupportSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.sources +import java.io.File + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod} +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -27,4 +32,44 @@ class BucketedWriteWithHiveSupportSuite extends BucketedWriteSuite with TestHive } override protected def fileFormatsToTest: 
Seq[String] = Seq("parquet", "orc") + + test("write hive bucketed table") { + def bucketIdExpression(expressions: Seq[Expression], numBuckets: Int): Expression = + Pmod(BitwiseAnd(HiveHash(expressions), Literal(Int.MaxValue)), Literal(8)) + + def getBucketIdFromFileName(fileName: String): Option[Int] = { + val hiveBucketedFileName = """^(\d+)_0_.*$""".r + fileName match { + case hiveBucketedFileName(bucketId) => Some(bucketId.toInt) + case _ => None + } + } + + val table = "hive_bucketed_table" + withTable(table) { + sql( + s""" + |CREATE TABLE IF NOT EXISTS $table (i int, j string) + |PARTITIONED BY(k string) + |CLUSTERED BY (i, j) SORTED BY (i) INTO 8 BUCKETS + |STORED AS PARQUET + """.stripMargin) + + val df = + (0 until 50).map(i => (i % 13, i.toString, i % 5)).toDF("i", "j", "k") + df.write.mode(SaveMode.Overwrite).insertInto(table) + + for (k <- 0 until 5) { + testBucketing( + new File(tableDir(table), s"k=$k"), + "parquet", + 8, + Seq("i", "j"), + Seq("i"), + df, + bucketIdExpression, + getBucketIdFromFileName) + } + } + } }