From 5be4dff693afe5379e4fe0ece081b91b98cc7c9a Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Fri, 23 Mar 2018 23:17:29 +0000 Subject: [PATCH 01/12] add bucket pruning functionality --- .../sql/execution/DataSourceScanExec.scala | 27 ++++- .../datasources/FileSourceStrategy.scala | 112 ++++++++++++++++++ 2 files changed, 134 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..bb96a795ddd73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation @@ -151,6 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. + * @param OptionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. * @param tableIdentifier identifier for the table in the metastore. */ @@ -159,6 +161,7 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], + OptionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -371,14 +374,27 @@ case class FileSourceScanExec( val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts) } - }.groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedBucketed = if (OptionalBucketSet.isDefined) { + val bucketSet = OptionalBucketSet.get + bucketed.filter { + f => bucketSet.get( + BucketingUtils.getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))) + } + } else { + bucketed + } + + val filesGroupedToBuckets = prunedBucketed.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + FilePartition(bucketId, filesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) @@ -503,6 +519,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), + OptionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..bc2b724c6275f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.BitSet /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -50,6 +53,107 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. */ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets(expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketId(attr, numBuckets, v)) + matchedBuckets + } + + expr match { + case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + val valuesSet = list.map(e => e.eval(EmptyRow)) + valuesSet + .map(v => getMatchedBucketBitSet(a, v)) + .fold(new BitSet(numBuckets))(_ | _) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getMatchedBucketBitSet(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + val leftBuckets = getExpressionBuckets(left, bucketColumnName, numBuckets) + val rightBuckets = getExpressionBuckets(right, bucketColumnName, numBuckets) + + // if some expression in OR condition requires all buckets, return an empty BitSet + if (leftBuckets.cardinality() == 0 || rightBuckets.cardinality() == 0) { + new BitSet(numBuckets) + } else { + // return a BitSet that includes all required buckets + leftBuckets | rightBuckets + } + case _ => new BitSet(numBuckets) + } + } + + private def getBuckets(normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + // val matchedBuckets = new BitSet(numBuckets) + // matchedBuckets.clear() + + // TODO should be OR? + val matchedBuckets = normalizedFilters + .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) + .fold(new BitSet(numBuckets))(_ | _) + + val numBucketsSelected = if (matchedBuckets.cardinality() != 0) matchedBuckets.cardinality() + else numBuckets + logInfo { + s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." + } + + // None means all the buckets need to be scanned + if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) + } + + // // Get the bucket ID based on the bucketing values. + // // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + // def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + // val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + // mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) + // val bucketIdGeneration = UnsafeProjection.create( + // HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, + // bucketColumn :: Nil) + // + // bucketIdGeneration(mutableRow).getInt(0) + // } + + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + // TODO replace with getBucketId from DataSourceStrategy + private def getBucketId(attr: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(attr.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(attr), numBuckets).partitionIdExpression :: Nil, + attr :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => @@ -79,6 +183,13 @@ object FileSourceStrategy extends Strategy with Logging { ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec + val bucketSet = if (shouldPruneBuckets(bucketSpec)) { + getBuckets(normalizedFilters, bucketSpec.get) + } else { + None + } + val dataColumns = l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) @@ -108,6 +219,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, + bucketSet, dataFilters, table.map(_.identifier)) From 4ab1583d26ab039a8458784d0a93d6aa25077603 Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Fri, 23 Mar 2018 23:33:25 +0000 Subject: [PATCH 02/12] add composite filters test cases and refactor pruning test --- .../spark/sql/sources/BucketedReadSuite.scala | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index fb61fa716b946..57699ef34bedb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -80,14 +80,14 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } + // TODO update description // To verify if the bucket pruning works, this function checks two conditions: // 1) Check if the pruned buckets (before filtering) are empty. // 2) Verify the final result is the same as the expected one - private def checkPrunedAnswers( - bucketSpec: BucketSpec, - bucketValues: Seq[Integer], - filterCondition: Column, - originalDataFrame: DataFrame): Unit = { + private def checkPrunedAnswers(bucketSpec: BucketSpec, + bucketValues: Seq[Integer], + filterCondition: Column, + originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { val strategy = DataSourceStrategy(spark.sessionState.conf) @@ -97,25 +97,31 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { assert(bucketColumnNames.length == 1) val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) - } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) assert(rdd.isDefined, plan) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + // if nothing should be pruned, skip the pruning test + if (bucketValues.nonEmpty) { + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) + } + val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of partitions that should have been pruned and are not empty + if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (invalidBuckets.nonEmpty) { + fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") + } } - // TODO: These tests are not testing the right columns. -// // checking if all the pruned buckets are empty -// val invalidBuckets = checkedResult.collect().toList -// if (invalidBuckets.nonEmpty) { -// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") -// } checkAnswer( bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), @@ -229,6 +235,30 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = j :: Nil, filterCondition = $"j" === j && $"i" > j % 5, df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || $"j" === (j + 1), + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === j || $"i" === 0, + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || ($"j" === (j + 1) && $"i" === 0), + df) } } } From 3bb7a2eecf8c2d15e1abc728e8432be9e346e22a Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Mon, 26 Mar 2018 21:52:56 +0300 Subject: [PATCH 03/12] remove redundant code and move shared getBucketId to BucketingUtils --- .../datasources/BucketingUtils.scala | 14 +++++++++ .../datasources/DataSourceStrategy.scala | 12 -------- .../datasources/FileSourceStrategy.scala | 29 +------------------ .../spark/sql/sources/BucketedReadSuite.scala | 18 +++++------- 4 files changed, 22 insertions(+), 51 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index ea4fe9c8ade5f..a776fc3e7021d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name @@ -35,5 +38,16 @@ object BucketingUtils { case other => None } + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c08065..7b129435c45db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case _ => Nil } - // Get the bucket ID based on the bucketing values. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - val bucketIdGeneration = UnsafeProjection.create( - HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - bucketColumn :: Nil) - - bucketIdGeneration(mutableRow).getInt(0) - } - // Based on Public API. private def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index bc2b724c6275f..55b47e6c185f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -68,7 +68,7 @@ object FileSourceStrategy extends Strategy with Logging { def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = { val matchedBuckets = new BitSet(numBuckets) - matchedBuckets.set(getBucketId(attr, numBuckets, v)) + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)) matchedBuckets } @@ -112,10 +112,7 @@ object FileSourceStrategy extends Strategy with Logging { val bucketColumnName = bucketSpec.bucketColumnNames.head val numBuckets = bucketSpec.numBuckets - // val matchedBuckets = new BitSet(numBuckets) - // matchedBuckets.clear() - // TODO should be OR? val matchedBuckets = normalizedFilters .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) .fold(new BitSet(numBuckets))(_ | _) @@ -130,30 +127,6 @@ object FileSourceStrategy extends Strategy with Logging { if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) } - // // Get the bucket ID based on the bucketing values. - // // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - // def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - // val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - // mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - // val bucketIdGeneration = UnsafeProjection.create( - // HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - // bucketColumn :: Nil) - // - // bucketIdGeneration(mutableRow).getInt(0) - // } - - // Given bucketColumn, numBuckets and value, returns the corresponding bucketId - // TODO replace with getBucketId from DataSourceStrategy - private def getBucketId(attr: Attribute, numBuckets: Int, value: Any): Int = { - val mutableInternalRow = new SpecificInternalRow(Seq(attr.dataType)) - mutableInternalRow.update(0, value) - - val bucketIdGenerator = UnsafeProjection.create( - HashPartitioning(Seq(attr), numBuckets).partitionIdExpression :: Nil, - attr :: Nil) - bucketIdGenerator(mutableInternalRow).getInt(0) - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 57699ef34bedb..d5b964396d529 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet @@ -90,7 +90,6 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -107,7 +106,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { if (bucketValues.nonEmpty) { val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) } val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => // return indexes of partitions that should have been pruned and are not empty @@ -236,29 +235,26 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { filterCondition = $"j" === j && $"i" > j % 5, df) + // check multiple bucket values OR condition checkPrunedAnswers( bucketSpec, bucketValues = Seq(j, j + 1), filterCondition = $"j" === j || $"j" === (j + 1), df) + // check bucket value and none bucket value OR condition checkPrunedAnswers( bucketSpec, bucketValues = Nil, filterCondition = $"j" === j || $"i" === 0, df) + // check AND condition in complex expression checkPrunedAnswers( bucketSpec, bucketValues = Seq(j), filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, df) - - checkPrunedAnswers( - bucketSpec, - bucketValues = Seq(j, j + 1), - filterCondition = $"j" === j || ($"j" === (j + 1) && $"i" === 0), - df) } } } From c45da4b6dd523fd5f2eaa041a5f684ab80db02f4 Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Tue, 27 Mar 2018 19:31:34 +0300 Subject: [PATCH 04/12] fix variable name and add select buckets count to metadata --- .../sql/execution/DataSourceScanExec.scala | 25 ++++++++++++++----- .../spark/sql/sources/BucketedReadSuite.scala | 1 - 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index bb96a795ddd73..de5322d1c5f89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -152,7 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. - * @param OptionalBucketSet Bucket ids for bucket pruning + * @param optionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. * @param tableIdentifier identifier for the table in the metastore. */ @@ -161,7 +161,7 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], - OptionalBucketSet: Option[BitSet], + optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -289,7 +289,20 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + + val withSelectedBucketsCount = relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + withOptPartitionCount + ("SelectedBucketsCount" -> + s"$numSelectedBuckets out of ${spec.numBuckets}") + } getOrElse { + withOptPartitionCount + } + + withSelectedBucketsCount } private lazy val inputRDD: RDD[InternalRow] = { @@ -376,8 +389,8 @@ case class FileSourceScanExec( } } - val prunedBucketed = if (OptionalBucketSet.isDefined) { - val bucketSet = OptionalBucketSet.get + val prunedBucketed = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get bucketed.filter { f => bucketSet.get( BucketingUtils.getBucketId(new Path(f.filePath).getName) @@ -519,7 +532,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), - OptionalBucketSet, + optionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index d5b964396d529..f04cac6b24003 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -80,7 +80,6 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } - // TODO update description // To verify if the bucket pruning works, this function checks two conditions: // 1) Check if the pruned buckets (before filtering) are empty. // 2) Verify the final result is the same as the expected one From f0b84bd9231915f30f9643c083b2ab35cbbce472 Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Tue, 27 Mar 2018 20:10:25 +0300 Subject: [PATCH 05/12] optimize imports --- .../spark/sql/execution/datasources/FileSourceStrategy.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 55b47e6c185f9..af11695931b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.util.collection.BitSet /** From cb36012f8aadc9e6c842e1ff143394dd5de71be5 Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Tue, 27 Mar 2018 20:47:19 +0300 Subject: [PATCH 06/12] fix checkstyle errors --- .../org/apache/spark/sql/sources/BucketedReadSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index f04cac6b24003..82a88472c724d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec -import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet From 8f6bc28752a5ef59d8914f4b37e9e47339feb04f Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Thu, 29 Mar 2018 14:46:31 +0300 Subject: [PATCH 07/12] style guidelines --- .../sql/execution/datasources/FileSourceStrategy.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index af11695931b48..7ca2f58188d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -115,8 +115,13 @@ object FileSourceStrategy extends Strategy with Logging { .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) .fold(new BitSet(numBuckets))(_ | _) - val numBucketsSelected = if (matchedBuckets.cardinality() != 0) matchedBuckets.cardinality() - else numBuckets + val numBucketsSelected = if (matchedBuckets.cardinality() != 0) { + matchedBuckets.cardinality() + } + else { + numBuckets + } + logInfo { s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." } From 0a4aeab164337d0c7c5c2758fc1a3ebf998455d8 Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Fri, 30 Mar 2018 17:06:00 +0300 Subject: [PATCH 08/12] if all column are required, return BitSet with 1's instead of 0's. --- .../datasources/FileSourceStrategy.scala | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 7ca2f58188d5f..8778e460c9efa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -88,46 +88,45 @@ object FileSourceStrategy extends Strategy with Logging { case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => getMatchedBucketBitSet(a, null) case expressions.And(left, right) => - getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(left, bucketColumnName, numBuckets) & getExpressionBuckets(right, bucketColumnName, numBuckets) case expressions.Or(left, right) => - val leftBuckets = getExpressionBuckets(left, bucketColumnName, numBuckets) - val rightBuckets = getExpressionBuckets(right, bucketColumnName, numBuckets) - - // if some expression in OR condition requires all buckets, return an empty BitSet - if (leftBuckets.cardinality() == 0 || rightBuckets.cardinality() == 0) { - new BitSet(numBuckets) - } else { - // return a BitSet that includes all required buckets - leftBuckets | rightBuckets - } - case _ => new BitSet(numBuckets) + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case _ => + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.setUntil(numBuckets) + matchedBuckets } } private def getBuckets(normalizedFilters: Seq[Expression], bucketSpec: BucketSpec): Option[BitSet] = { + if (normalizedFilters.isEmpty) { + return None + } + val bucketColumnName = bucketSpec.bucketColumnNames.head val numBuckets = bucketSpec.numBuckets - val matchedBuckets = normalizedFilters - .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) - .fold(new BitSet(numBuckets))(_ | _) + val normalizedFiltersAndExpr = normalizedFilters + .reduce(expressions.And) + val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, + numBuckets) - val numBucketsSelected = if (matchedBuckets.cardinality() != 0) { - matchedBuckets.cardinality() - } - else { - numBuckets - } + val numBucketsSelected = matchedBuckets.cardinality() logInfo { s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." } // None means all the buckets need to be scanned - if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) + if (numBucketsSelected == numBuckets) { + None + } else { + Some(matchedBuckets) + } } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { From 697784d57642ffe2bda55e518a34265829e4ba2c Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Fri, 30 Mar 2018 17:19:20 +0300 Subject: [PATCH 09/12] use bucket numbers that don't generate empty buckets on df and nullDF --- .../spark/sql/sources/BucketedReadSuite.scala | 49 +++++++++++++++++-- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 82a88472c724d..2d008efbced48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -52,6 +52,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF + // empty buckets before filtering might hide bugs in pruning logic + private val NumBucketsForPruningDF = 7 + private val NumBucketsForPruningNullDf = 5 + test("read bucketed data") { withTable("bucketed_table") { df.write @@ -117,7 +122,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { }.collect() if (invalidBuckets.nonEmpty) { - fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") + fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan") } } @@ -129,7 +134,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -165,7 +170,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read non-partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -185,7 +190,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having null in bucketing key") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here nullDF.write @@ -212,7 +217,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having composite filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -258,6 +263,40 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } + test("read bucketed table without filters") { + withTable("bucketed_table") { + val numBuckets = NumBucketsForPruningDF + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") + val plan = bucketedDataFrame.queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) + assert(rdd.isDefined, plan) + + val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of empty partitions + if (iter.isEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (emptyBuckets.nonEmpty) { + fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan") + } + + checkAnswer( + bucketedDataFrame.orderBy("i", "j", "k"), + df.orderBy("i", "j", "k")) + } + } + private lazy val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") private lazy val df2 = From de9ecb64068d1cafa79ce61a2bf1795981a46f46 Mon Sep 17 00:00:00 2001 From: asaban Date: Wed, 6 Jun 2018 00:08:15 +0100 Subject: [PATCH 10/12] CR fixes --- .../datasources/FileSourceStrategy.scala | 51 +++++++++++-------- .../spark/sql/sources/BucketedReadSuite.scala | 18 +++++-- 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 8778e460c9efa..8b352168be3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -60,33 +60,40 @@ object FileSourceStrategy extends Strategy with Logging { } } - private def getExpressionBuckets(expr: Expression, - bucketColumnName: String, - numBuckets: Int): BitSet = { + private def getExpressionBuckets( + expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { - def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = { + def getBucketNumber(attr: Attribute, v: Any): Int = { + BucketingUtils.getBucketIdFromValue(attr, numBuckets, v) + } + + def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = { val matchedBuckets = new BitSet(numBuckets) - matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)) + iter + .map(v => getBucketNumber(attr, v)) + .foreach(bucketNum => matchedBuckets.set(bucketNum)) + matchedBuckets + } + + def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketNumber(attr, v)) matchedBuckets } expr match { - case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - getMatchedBucketBitSet(a, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - getMatchedBucketBitSet(a, v) - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - getMatchedBucketBitSet(a, v) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - getMatchedBucketBitSet(a, v) + case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getBucketSetFromValue(a, v) case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => - val valuesSet = list.map(e => e.eval(EmptyRow)) - valuesSet - .map(v => getMatchedBucketBitSet(a, v)) - .fold(new BitSet(numBuckets))(_ | _) + getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) + case expressions.InSet(a: Attribute, hset) + if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => - getMatchedBucketBitSet(a, null) + getBucketSetFromValue(a, null) case expressions.And(left, right) => getExpressionBuckets(left, bucketColumnName, numBuckets) & getExpressionBuckets(right, bucketColumnName, numBuckets) @@ -100,9 +107,9 @@ object FileSourceStrategy extends Strategy with Logging { } } - private def getBuckets(normalizedFilters: Seq[Expression], - bucketSpec: BucketSpec): Option[BitSet] = { - + private def genBucketSet( + normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { if (normalizedFilters.isEmpty) { return None } @@ -160,7 +167,7 @@ object FileSourceStrategy extends Strategy with Logging { val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec val bucketSet = if (shouldPruneBuckets(bucketSpec)) { - getBuckets(normalizedFilters, bucketSpec.get) + genBucketSet(normalizedFilters, bucketSpec.get) } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 2d008efbced48..a9414200e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,6 +22,7 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} @@ -88,10 +89,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { // To verify if the bucket pruning works, this function checks two conditions: // 1) Check if the pruned buckets (before filtering) are empty. // 2) Verify the final result is the same as the expected one - private def checkPrunedAnswers(bucketSpec: BucketSpec, - bucketValues: Seq[Integer], - filterCondition: Column, - originalDataFrame: DataFrame): Unit = { + private def checkPrunedAnswers( + bucketSpec: BucketSpec, + bucketValues: Seq[Integer], + filterCondition: Column, + originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") @@ -164,6 +166,14 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = Seq(j, j + 1, j + 2, j + 3), filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = Column(inSetExpr), + df) } } } From f94912896c1859d078fc52fda867beaecaa68a47 Mon Sep 17 00:00:00 2001 From: asaban Date: Wed, 6 Jun 2018 11:14:33 +0100 Subject: [PATCH 11/12] calculate bucket num only once --- .../sql/execution/DataSourceScanExec.scala | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index de5322d1c5f89..d0efa74d6094b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -389,25 +389,23 @@ case class FileSourceScanExec( } } - val prunedBucketed = if (optionalBucketSet.isDefined) { - val bucketSet = optionalBucketSet.get - bucketed.filter { - f => bucketSet.get( - BucketingUtils.getBucketId(new Path(f.filePath).getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))) - } - } else { - bucketed - } - - val filesGroupedToBuckets = prunedBucketed.groupBy { f => + val filesGroupedToBuckets = bucketed.groupBy { f => BucketingUtils .getBucketId(new Path(f.filePath).getName) .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, filesGroupedToBuckets.getOrElse(bucketId, Nil)) + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) From 72712c951338d3acbd1e141717d779ed33b03d82 Mon Sep 17 00:00:00 2001 From: asaban Date: Wed, 6 Jun 2018 11:20:40 +0100 Subject: [PATCH 12/12] inline statement --- .../spark/sql/execution/DataSourceScanExec.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d0efa74d6094b..610bbdaa8a488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -381,20 +381,18 @@ case class FileSourceScanExec( selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val bucketed = + val filesGroupedToBuckets = selectedPartitions.flatMap { p => p.files.map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts) } + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } - val filesGroupedToBuckets = bucketed.groupBy { f => - BucketingUtils - .getBucketId(new Path(f.filePath).getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) - } - val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { val bucketSet = optionalBucketSet.get filesGroupedToBuckets.filter {