From 9059a673d88869b5cca7edb81df1cfdab7f2900a Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 24 Jun 2021 00:08:59 -0700 Subject: [PATCH 1/3] Core : Add Files Perf improvement by push down partition filter to Spark/Hive catalog --- build.gradle | 5 +++- .../apache/iceberg/spark/SparkTableUtil.java | 24 ++++++++++++------- .../extensions/TestAddFilesProcedure.java | 20 ++++++++++++++++ versions.props | 1 + 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/build.gradle b/build.gradle index d87eacebb692..1ea13e87c1ef 100644 --- a/build.gradle +++ b/build.gradle @@ -833,7 +833,10 @@ project(':iceberg-spark') { compile project(':iceberg-parquet') compile project(':iceberg-arrow') compile project(':iceberg-hive-metastore') - + compile ('org.scala-lang.modules:scala-java8-compat_2.12') { + exclude group: "org.scala-lang", module: 'scala-library' + } + compileOnly "org.apache.avro:avro" compileOnly("org.apache.spark:spark-hive_2.11") { exclude group: 'org.apache.avro', module: 'avro' diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java index 29c71a8a56ef..505918620252 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java +++ b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -26,6 +26,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -80,10 +81,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import scala.Function2; import scala.Option; +import scala.Predef; import scala.Some; import scala.Tuple2; import scala.collection.JavaConverters; import scala.collection.Seq; +import scala.compat.java8.OptionConverters; import scala.runtime.AbstractPartialFunction; import static org.apache.spark.sql.functions.col; @@ -140,7 +143,7 @@ public static Dataset partitionDFByFilter(SparkSession spark, String table, public static List getPartitions(SparkSession spark, String table) { try { TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); - return getPartitions(spark, tableIdent); + return getPartitions(spark, tableIdent, Optional.empty()); } catch (ParseException e) { throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse table identifier: %s", table); } @@ -151,15 +154,21 @@ public static List getPartitions(SparkSession spark, String tabl * * @param spark a Spark session * @param tableIdent a table identifier + * @param partitionFilter the partition filter * @return all table's partitions */ - public static List getPartitions(SparkSession spark, TableIdentifier tableIdent) { + public static List getPartitions(SparkSession spark, TableIdentifier tableIdent, + Optional> partitionFilter) { try { SessionCatalog catalog = spark.sessionState().catalog(); CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); - Seq partitions = catalog.listPartitions(tableIdent, Option.empty()); + Option> partSpec = + OptionConverters.toScala(partitionFilter.map(pf -> + JavaConverters.mapAsScalaMapConverter(pf).asScala() + .toMap(Predef.>conforms()))); + Seq partitions = catalog.listPartitions(tableIdent, partSpec); return JavaConverters .seqAsJavaListConverter(partitions) .asJava() @@ -375,14 +384,11 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa if (Objects.equal(spec, PartitionSpec.unpartitioned())) { importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable); } else { - List sourceTablePartitions = getPartitions(spark, sourceTableIdent); + List sourceTablePartitions = getPartitions(spark, sourceTableIdent, + Optional.of(partitionFilter)); Preconditions.checkArgument(!sourceTablePartitions.isEmpty(), "Cannot find any partitions in table %s", sourceTableIdent); - List filteredPartitions = filterPartitions(sourceTablePartitions, partitionFilter); - Preconditions.checkArgument(!filteredPartitions.isEmpty(), - "Cannot find any partitions which match the given filter. Partition filter is %s", - MAP_JOINER.join(partitionFilter)); - importSparkPartitions(spark, filteredPartitions, targetTable, spec, stagingDir); + importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir); } } catch (AnalysisException e) { throw SparkExceptionUtil.toUncheckedException( diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java index edc47f64c320..dea01a10a647 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestAddFilesProcedure.java @@ -318,6 +318,26 @@ public void addFilteredPartitionsToPartitioned() { sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); } + @Test + public void addFilteredPartitionsToPartitioned2() { + createCompositePartitionedTable("parquet"); + + String createIceberg = + "CREATE TABLE %s (id Integer, name String, dept String, subdept String) USING iceberg " + + "PARTITIONED BY (id, dept)"; + + sql(createIceberg, tableName); + + Object result = scalarSql("CALL %s.system.add_files('%s', '`parquet`.`%s`', map('dept', 'hr'))", + catalogName, tableName, fileTableDir.getAbsolutePath()); + + Assert.assertEquals(6L, result); + + assertEquals("Iceberg table contains correct data", + sql("SELECT id, name, dept, subdept FROM %s WHERE dept = 'hr' ORDER BY id", sourceTableName), + sql("SELECT id, name, dept, subdept FROM %s ORDER BY id", tableName)); + } + @Test public void addWeirdCaseHiveTable() { createWeirdCaseTable(); diff --git a/versions.props b/versions.props index 1b2c5f4f284f..27f9cb871e9b 100644 --- a/versions.props +++ b/versions.props @@ -18,6 +18,7 @@ org.apache.arrow:arrow-memory-netty = 2.0.0 com.github.stephenc.findbugs:findbugs-annotations = 1.3.9-1 software.amazon.awssdk:* = 2.15.7 org.scala-lang:scala-library = 2.12.10 +org.scala-lang.modules:scala-java8-compat_2.12 = 0.8.0 org.projectnessie:* = 0.5.1 javax.ws.rs:javax.ws.rs-api = 2.1.1 io.quarkus:* = 1.13.1.Final From fb91eee01ea35c1484b166b2ea639e8c5518529c Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 12 Jul 2021 17:57:49 -0700 Subject: [PATCH 2/3] Remove Scala conversion dependency and add unit test --- build.gradle | 5 +- .../apache/iceberg/spark/SparkTableUtil.java | 24 +++--- .../spark/source/TestSparkTableUtil.java | 85 +++++++++++++++++++ versions.props | 1 - 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/build.gradle b/build.gradle index 1ea13e87c1ef..d87eacebb692 100644 --- a/build.gradle +++ b/build.gradle @@ -833,10 +833,7 @@ project(':iceberg-spark') { compile project(':iceberg-parquet') compile project(':iceberg-arrow') compile project(':iceberg-hive-metastore') - compile ('org.scala-lang.modules:scala-java8-compat_2.12') { - exclude group: "org.scala-lang", module: 'scala-library' - } - + compileOnly "org.apache.avro:avro" compileOnly("org.apache.spark:spark-hive_2.11") { exclude group: 'org.apache.avro', module: 'avro' diff --git a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java index 505918620252..e822d45147b9 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java +++ b/spark/src/main/java/org/apache/iceberg/spark/SparkTableUtil.java @@ -26,7 +26,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -86,7 +85,6 @@ import scala.Tuple2; import scala.collection.JavaConverters; import scala.collection.Seq; -import scala.compat.java8.OptionConverters; import scala.runtime.AbstractPartialFunction; import static org.apache.spark.sql.functions.col; @@ -143,7 +141,7 @@ public static Dataset partitionDFByFilter(SparkSession spark, String table, public static List getPartitions(SparkSession spark, String table) { try { TableIdentifier tableIdent = spark.sessionState().sqlParser().parseTableIdentifier(table); - return getPartitions(spark, tableIdent, Optional.empty()); + return getPartitions(spark, tableIdent, null); } catch (ParseException e) { throw SparkExceptionUtil.toUncheckedException(e, "Unable to parse table identifier: %s", table); } @@ -154,21 +152,23 @@ public static List getPartitions(SparkSession spark, String tabl * * @param spark a Spark session * @param tableIdent a table identifier - * @param partitionFilter the partition filter + * @param partitionFilter partition filter, or null if no filter * @return all table's partitions */ public static List getPartitions(SparkSession spark, TableIdentifier tableIdent, - Optional> partitionFilter) { + Map partitionFilter) { try { SessionCatalog catalog = spark.sessionState().catalog(); CatalogTable catalogTable = catalog.getTableMetadata(tableIdent); - Option> partSpec = - OptionConverters.toScala(partitionFilter.map(pf -> - JavaConverters.mapAsScalaMapConverter(pf).asScala() - .toMap(Predef.>conforms()))); - - Seq partitions = catalog.listPartitions(tableIdent, partSpec); + Option> scalaPartitionFilter; + if (partitionFilter != null && !partitionFilter.isEmpty()) { + scalaPartitionFilter = Option.apply(JavaConverters.mapAsScalaMapConverter(partitionFilter).asScala() + .toMap(Predef.conforms())); + } else { + scalaPartitionFilter = Option.empty(); + } + Seq partitions = catalog.listPartitions(tableIdent, scalaPartitionFilter); return JavaConverters .seqAsJavaListConverter(partitions) .asJava() @@ -385,7 +385,7 @@ public static void importSparkTable(SparkSession spark, TableIdentifier sourceTa importUnpartitionedSparkTable(spark, sourceTableIdentWithDB, targetTable); } else { List sourceTablePartitions = getPartitions(spark, sourceTableIdent, - Optional.of(partitionFilter)); + partitionFilter); Preconditions.checkArgument(!sourceTablePartitions.isEmpty(), "Cannot find any partitions in table %s", sourceTableIdent); importSparkPartitions(spark, sourceTablePartitions, targetTable, spec, stagingDir); diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java index 967f7b86e298..ec7afed405f5 100644 --- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java +++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java @@ -21,7 +21,13 @@ import java.io.File; import java.io.IOException; +import java.util.AbstractMap; +import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.iceberg.FileFormat; @@ -34,6 +40,7 @@ import org.apache.iceberg.mapping.NameMappingParser; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.spark.SparkTableUtil; import org.apache.iceberg.spark.SparkTableUtil.SparkPartition; @@ -369,6 +376,84 @@ public void testImportUnpartitionedWithWhitespace() throws Exception { } } + public static class GetPartitions { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + // This logic does not really depend on format + private final FileFormat format = FileFormat.PARQUET; + + @Test + public void testPartitionScan() throws Exception { + + List records = Lists.newArrayList( + new ThreeColumnRecord(1, "ab", "data"), + new ThreeColumnRecord(2, "b c", "data"), + new ThreeColumnRecord(1, "b c", "data"), + new ThreeColumnRecord(2, "ab", "data")); + + File location = temp.newFolder("partitioned_table"); + String tableName = "external_table"; + + spark.createDataFrame(records, ThreeColumnRecord.class) + .write().mode("overwrite").format(format.toString()) + .partitionBy("c1", "c2").saveAsTable(tableName); + + TableIdentifier source = spark.sessionState().sqlParser() + .parseTableIdentifier(tableName); + + Map partition1 = Stream.of( + new AbstractMap.SimpleImmutableEntry<>("c1", "1"), + new AbstractMap.SimpleImmutableEntry<>("c2", "ab")) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + Map partition2 = Stream.of( + new AbstractMap.SimpleImmutableEntry<>("c1", "2"), + new AbstractMap.SimpleImmutableEntry<>("c2", "b c")) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + + Map partition3 = Stream.of( + new AbstractMap.SimpleImmutableEntry<>("c1", "1"), + new AbstractMap.SimpleImmutableEntry<>("c2", "b c")) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + Map partition4 = Stream.of( + new AbstractMap.SimpleImmutableEntry<>("c1", "2"), + new AbstractMap.SimpleImmutableEntry<>("c2", "ab")) + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + + List partitionsC11 = + SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c1", "1")); + Set> expectedC11 = + Sets.newHashSet(partition1, partition3); + Set> actualC11 = partitionsC11.stream().map( + p -> p.getValues()).collect(Collectors.toSet()); + Assert.assertEquals("Wrong partitions fetched for c1=1", expectedC11, actualC11); + + List partitionsC12 = + SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c1", "2")); + Set> expectedC12 = Sets.newHashSet(partition2, partition4); + Set> actualC12 = partitionsC12.stream().map( + p -> p.getValues()).collect(Collectors.toSet()); + Assert.assertEquals("Wrong partitions fetched for c1=2", expectedC12, actualC12); + + List partitionsC21 = + SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c2", "ab")); + Set> expectedC21 = + Sets.newHashSet(partition1, partition4); + Set> actualC21 = partitionsC21.stream().map( + p -> p.getValues()).collect(Collectors.toSet()); + Assert.assertEquals("Wrong partitions fetched for c2=ab", expectedC21, actualC21); + + List partitionsC22 = + SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c2", "b c")); + Set> expectedC22 = + Sets.newHashSet(partition2, partition3); + Set> actualC22 = partitionsC22.stream().map( + p -> p.getValues()).collect(Collectors.toSet()); + Assert.assertEquals("Wrong partitions fetched for c2=b c", expectedC22, actualC22); + } + } + public static class PartitionScan { @Before diff --git a/versions.props b/versions.props index 27f9cb871e9b..1b2c5f4f284f 100644 --- a/versions.props +++ b/versions.props @@ -18,7 +18,6 @@ org.apache.arrow:arrow-memory-netty = 2.0.0 com.github.stephenc.findbugs:findbugs-annotations = 1.3.9-1 software.amazon.awssdk:* = 2.15.7 org.scala-lang:scala-library = 2.12.10 -org.scala-lang.modules:scala-java8-compat_2.12 = 0.8.0 org.projectnessie:* = 0.5.1 javax.ws.rs:javax.ws.rs-api = 2.1.1 io.quarkus:* = 1.13.1.Final From 8b9d576c74c510d19a6cc88c3c20022a6257f7bb Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 12 Jul 2021 21:04:48 -0700 Subject: [PATCH 3/3] Use guava in new test --- .../spark/source/TestSparkTableUtil.java | 41 ++++++++----------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java index ec7afed405f5..1c9362a26133 100644 --- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java +++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestSparkTableUtil.java @@ -21,13 +21,10 @@ import java.io.File; import java.io.IOException; -import java.util.AbstractMap; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.Stream; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.iceberg.FileFormat; @@ -393,7 +390,6 @@ public void testPartitionScan() throws Exception { new ThreeColumnRecord(1, "b c", "data"), new ThreeColumnRecord(2, "ab", "data")); - File location = temp.newFolder("partitioned_table"); String tableName = "external_table"; spark.createDataFrame(records, ThreeColumnRecord.class) @@ -403,26 +399,21 @@ public void testPartitionScan() throws Exception { TableIdentifier source = spark.sessionState().sqlParser() .parseTableIdentifier(tableName); - Map partition1 = Stream.of( - new AbstractMap.SimpleImmutableEntry<>("c1", "1"), - new AbstractMap.SimpleImmutableEntry<>("c2", "ab")) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); - Map partition2 = Stream.of( - new AbstractMap.SimpleImmutableEntry<>("c1", "2"), - new AbstractMap.SimpleImmutableEntry<>("c2", "b c")) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); - - Map partition3 = Stream.of( - new AbstractMap.SimpleImmutableEntry<>("c1", "1"), - new AbstractMap.SimpleImmutableEntry<>("c2", "b c")) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); - Map partition4 = Stream.of( - new AbstractMap.SimpleImmutableEntry<>("c1", "2"), - new AbstractMap.SimpleImmutableEntry<>("c2", "ab")) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); + Map partition1 = ImmutableMap.of( + "c1", "1", + "c2", "ab"); + Map partition2 = ImmutableMap.of( + "c1", "2", + "c2", "b c"); + Map partition3 = ImmutableMap.of( + "c1", "1", + "c2", "b c"); + Map partition4 = ImmutableMap.of( + "c1", "2", + "c2", "ab"); List partitionsC11 = - SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c1", "1")); + SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "1")); Set> expectedC11 = Sets.newHashSet(partition1, partition3); Set> actualC11 = partitionsC11.stream().map( @@ -430,14 +421,14 @@ public void testPartitionScan() throws Exception { Assert.assertEquals("Wrong partitions fetched for c1=1", expectedC11, actualC11); List partitionsC12 = - SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c1", "2")); + SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c1", "2")); Set> expectedC12 = Sets.newHashSet(partition2, partition4); Set> actualC12 = partitionsC12.stream().map( p -> p.getValues()).collect(Collectors.toSet()); Assert.assertEquals("Wrong partitions fetched for c1=2", expectedC12, actualC12); List partitionsC21 = - SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c2", "ab")); + SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "ab")); Set> expectedC21 = Sets.newHashSet(partition1, partition4); Set> actualC21 = partitionsC21.stream().map( @@ -445,7 +436,7 @@ public void testPartitionScan() throws Exception { Assert.assertEquals("Wrong partitions fetched for c2=ab", expectedC21, actualC21); List partitionsC22 = - SparkTableUtil.getPartitions(spark, source, Collections.singletonMap("c2", "b c")); + SparkTableUtil.getPartitions(spark, source, ImmutableMap.of("c2", "b c")); Set> expectedC22 = Sets.newHashSet(partition2, partition3); Set> actualC22 = partitionsC22.stream().map(