Summary: push the partition filter down to the Spark session catalog.

SparkTableUtil.getPartitions(SparkSession, TableIdentifier, Map) converts a non-empty
filter to a Scala immutable Map and passes it to SessionCatalog.listPartitions; a null or
empty filter lists every partition. importSparkTable now relies on that call instead of
filterPartitions, and tests are added for the filtered listing and for add_files with a
partition filter.

NOTE(review): the getPartitions(SparkSession, TableIdentifier) overload is replaced, not kept
alongside the new one; confirm no callers outside this patch depend on it.
NOTE(review): the filter-specific precondition message is removed, so a filter that matches
nothing is now reported as "Cannot find any partitions in table ..."; confirm that is intended.

diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
index 29c71a8a56ef..e822d45147b9 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java
@@ -80,6 +80,7 @@
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import scala.Function2;
 import scala.Option;
+import scala.Predef;
 import scala.Some;
 import scala.Tuple2;
 import scala.collection.JavaConverters;
@@ -140,7 +141,7 @@ public static Dataset<Row> partitionDFByFilter(SparkSession spark, String table,
   public static List<SparkPartition> getPartitions(SparkSession spark, String table) {
     try {
       TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table);
-      return getPartitions(spark, tableIdent);
+      return getPartitions(spark, tableIdent, null);
     } catch (ParseException e) {
       throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse table identifier: %s", table);
     }
@@ -151,15 +152,23 @@ public static List<SparkPartition> getPartitions(SparkSession spark, String tabl
    *
    * @param spark a Spark session
    * @param tableIdent a table identifier
+   * @param partitionFilter partition filter, or null if no filter
    * @return all table's partitions
    */
-  public static List<SparkPartition> getPartitions(SparkSession spark, TableIdentifier tableIdent) {
+  public static List<SparkPartition> getPartitions(SparkSession spark, TableIdentifier tableIdent,
+                                                   Map<String, String> partitionFilter) {
     try {
       SessionCatalog catalog = spark.sessionState().catalog();
       CatalogTable catalogTable = catalog.getTableMetadata(tableIdent);
 
-      Seq<CatalogTablePartition> partitions = catalog.listPartitions(tableIdent, Option.empty());
-
+      Option<scala.collection.immutable.Map<String, String>> scalaPartitionFilter;
+      if (partitionFilter != null && !partitionFilter.isEmpty()) {
+        scalaPartitionFilter = Option.apply(JavaConverters.mapAsScalaMapConverter(partitionFilter).asScala()
+            .toMap(Predef.conforms()));
+      } else {
+        scalaPartitionFilter = Option.empty();
+      }
+      Seq<CatalogTablePartition> partitions = catalog.listPartitions(tableIdent, scalaPartitionFilter);
       return JavaConverters
           .seqAsJavaListConverter(partitions)
           .asJava()
@@ -375,14 +384,11 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa
       if (Objects.equal(spec, PartitionSpec.unpartitioned())) {
         importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable);
       } else {
-        List<SparkPartition> sourceTablePartitions = getPartitions(spark, sourceTableIdent);
+        List<SparkPartition> sourceTablePartitions = getPartitions(spark, sourceTableIdent,
+            partitionFilter);
         Preconditions.checkArgument(!sourceTablePartitions.isEmpty(),
             "Cannot find any partitions in table %s", sourceTableIdent);
-        List<SparkPartition> filteredPartitions = filterPartitions(sourceTablePartitions, partitionFilter);
-        Preconditions.checkArgument(!filteredPartitions.isEmpty(),
-            "Cannot find any partitions which match the given filter. Partition filter is %s",
-            MAP_JOINER.join(partitionFilter));
-        importSparkPartitions(spark, filteredPartitions, targetTable, spec, stagingDir);
+        importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir);
       }
     } catch (AnalysisException e) {
       throw SparkExceptionUtil.toUncheckedException(
diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
index 967f7b86e298..1c9362a26133 100644
--- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java
@@ -22,6 +22,9 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.iceberg.FileFormat;
@@ -34,6 +37,7 @@
 import org.apache.iceberg.mapping.NameMappingParser;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.SparkTableUtil.SparkPartition;
@@ -369,6 +373,78 @@ public void testImportUnpartitionedWithWhitespace() throws Exception {
     }
   }
 
+  public static class GetPartitions {
+
+    @Rule
+    public TemporaryFolder temp = new TemporaryFolder();
+
+    // This logic does not really depend on format
+    private final FileFormat format = FileFormat.PARQUET;
+
+    @Test
+    public void testPartitionScan() throws Exception {
+
+      List<ThreeColumnRecord> records = Lists.newArrayList(
+          new ThreeColumnRecord(1, "ab", "data"),
+          new ThreeColumnRecord(2, "b c", "data"),
+          new ThreeColumnRecord(1, "b c", "data"),
+          new ThreeColumnRecord(2, "ab", "data"));
+
+      String tableName = "external_table";
+
+      spark.createDataFrame(records, ThreeColumnRecord.class)
+          .write().mode("overwrite").format(format.toString())
+          .partitionBy("c1", "c2").saveAsTable(tableName);
+
+      TableIdentifier source = spark.sessionState().sqlParser()
+          .parseTableIdentifier(tableName);
+
+      Map<String, String> partition1 = ImmutableMap.of(
+          "c1", "1",
+          "c2", "ab");
+      Map<String, String> partition2 = ImmutableMap.of(
+          "c1", "2",
+          "c2", "b c");
+      Map<String, String> partition3 = ImmutableMap.of(
+          "c1", "1",
+          "c2", "b c");
+      Map<String, String> partition4 = ImmutableMap.of(
+          "c1", "2",
+          "c2", "ab");
+
+      List<SparkPartition> partitionsC11 =
+          SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "1"));
+      Set<Map<String, String>> expectedC11 =
+          Sets.newHashSet(partition1, partition3);
+      Set<Map<String, String>> actualC11 = partitionsC11.stream().map(
+          p -> p.getValues()).collect(Collectors.toSet());
+      Assert.assertEquals("Wrong partitions fetched for c1=1", expectedC11, actualC11);
+
+      List<SparkPartition> partitionsC12 =
+          SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "2"));
+      Set<Map<String, String>> expectedC12 = Sets.newHashSet(partition2, partition4);
+      Set<Map<String, String>> actualC12 = partitionsC12.stream().map(
+          p -> p.getValues()).collect(Collectors.toSet());
+      Assert.assertEquals("Wrong partitions fetched for c1=2", expectedC12, actualC12);
+
+      List<SparkPartition> partitionsC21 =
+          SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "ab"));
+      Set<Map<String, String>> expectedC21 =
+          Sets.newHashSet(partition1, partition4);
+      Set<Map<String, String>> actualC21 = partitionsC21.stream().map(
+          p -> p.getValues()).collect(Collectors.toSet());
+      Assert.assertEquals("Wrong partitions fetched for c2=ab", expectedC21, actualC21);
+
+      List<SparkPartition> partitionsC22 =
+          SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "b c"));
+      Set<Map<String, String>> expectedC22 =
+          Sets.newHashSet(partition2, partition3);
+      Set<Map<String, String>> actualC22 = partitionsC22.stream().map(
+          p -> p.getValues()).collect(Collectors.toSet());
+      Assert.assertEquals("Wrong partitions fetched for c2=b c", expectedC22, actualC22);
+    }
+  }
+
   public static class PartitionScan {
 
     @Before
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
index edc47f64c320..dea01a10a647 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java
@@ -318,6 +318,26 @@ public void addFilteredPartitionsToPartitioned() {
         sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
   }
 
+  @Test
+  public void addFilteredPartitionsToPartitioned2() {
+    createCompositePartitionedTable("parquet");
+
+    String createIceberg =
+        "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " +
+            "PARTITIONED BY (id, dept)";
+
+    sql(createIceberg, tableName);
+
+    Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))",
+        catalogName, tableName, fileTableDir.getAbsolutePath());
+
+    Assert.assertEquals(6L, result);
+
+    assertEquals("Iceberg table contains correct data",
+        sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName),
+        sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName));
+  }
+
   @Test
   public void addWeirdCaseHiveTable() {
     createWeirdCaseTable();