From 2b8e735bdc21f4abc0ec155c763203b6de611bfc Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 5 Dec 2022 16:27:50 -0800 Subject: [PATCH] initial commit --- .../datasources/v2/BatchScanExec.scala | 36 +++++++++++++------ .../KeyGroupedPartitioningSuite.scala | 6 ++-- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 48569ddc07de5..0f7bdd9e1fb4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -81,18 +81,21 @@ case class BatchScanExec( val newRows = new InternalRowSet(p.expressions.map(_.dataType)) newRows ++= newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey()) - val oldRows = p.partitionValuesOpt.get - if (oldRows.size != newRows.size) { - throw new SparkException("Data source must have preserved the original partitioning " + - "during runtime filtering: the number of unique partition values obtained " + - s"through HasPartitionKey changed: before ${oldRows.size}, after ${newRows.size}") + val oldRows = p.partitionValuesOpt.get.toSet + // We require the new number of partition keys to be equal or less than the old number + // of partition keys here. In the case of less than, empty partitions will be added for + // those missing keys that are not present in the new input partitions. + if (oldRows.size < newRows.size) { + throw new SparkException("During runtime filtering, data source must either report " + + "the same number of partition keys, or a subset of partition keys from the " + + s"original. Before: ${oldRows.size} partition keys. After: ${newRows.size} " + + "partition keys") } - if (!oldRows.forall(newRows.contains)) { - throw new SparkException("Data source must have preserved the original partitioning " + - "during runtime filtering: the number of unique partition values obtained " + - s"through HasPartitionKey remain the same but do not exactly match") + if (!newRows.forall(oldRows.contains)) { + throw new SparkException("During runtime filtering, data source must not report new " + + "partition keys that are not present in the original partitioning.") } groupPartitions(newPartitions).get.map(_._2) @@ -114,8 +117,21 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { + var finalPartitions = filteredPartitions + + outputPartitioning match { + case p: KeyGroupedPartitioning => + val partitionMapping = finalPartitions.map(s => + s.head.asInstanceOf[HasPartitionKey].partitionKey() -> s).toMap + finalPartitions = p.partitionValuesOpt.get.map { partKey => + // Use empty partition for those partition keys that are not present + partitionMapping.getOrElse(partKey, Seq.empty) + } + case _ => + } + new DataSourceRDD( - sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics) + sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) } postDriverMetrics() rdd diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index c0dc326361693..b2b8951a97997 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -433,11 +433,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(2, 11.0, cast('2020-01-01' as timestamp)), " + s"(3, 19.5, cast('2020-02-01' as timestamp))") - // number of unique partitions changed after dynamic filtering - should throw exception + // number of unique partitions changed after dynamic filtering - the gap should be filled + // with empty partitions and the job should still succeed var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " + s"i.id = p.item_id AND i.price > 40.0") - val e = intercept[Exception](df.collect()) - assert(e.getMessage.contains("number of unique partition values")) + checkAnswer(df, Seq(Row(131))) // dynamic filtering doesn't change partitioning so storage-partitioned join should kick in df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p WHERE " +