From 9544432d939744ce88b5bbdae00728c47baa28db Mon Sep 17 00:00:00 2001 From: Li Xian Date: Tue, 30 Mar 2021 01:58:11 +0800 Subject: [PATCH 1/7] column reader with rowIndexes --- .../parquet/VectorizedColumnReader.java | 91 +++++++++++++- .../VectorizedParquetRecordReader.java | 3 +- .../ParquetColumnIndexBenchmark.scala | 118 ++++++++++++++++++ 3 files changed, 206 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 672b73e94c42f..b84b8f9b13883 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,7 +21,7 @@ import java.math.BigInteger; import java.time.ZoneId; import java.time.ZoneOffset; -import java.util.Arrays; +import java.util.*; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesInput; @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime; import org.apache.spark.sql.execution.datasources.DataSourceUtils; import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; @@ -111,6 +112,12 @@ public class VectorizedColumnReader { private final String datetimeRebaseMode; private final String int96RebaseMode; + // TODO handle and init these filed properly + private Optional rowIndexesIterator; + private long[] rowIndexes; + private WritableColumnVector tempVector; + private Long currentRow; + private boolean isDecimalTypeMatched(DataType dt) { DecimalType d = (DecimalType) dt; DecimalMetadata dm = descriptor.getPrimitiveType().getDecimalMetadata(); @@ -140,7 +147,10 @@ public VectorizedColumnReader( PageReader pageReader, ZoneId convertTz, String datetimeRebaseMode, - String int96RebaseMode) throws IOException { + String int96RebaseMode, + Optional rowIndexesIterator + ) throws IOException { + this.rowIndexesIterator = rowIndexesIterator; this.descriptor = descriptor; this.pageReader = pageReader; this.convertTz = convertTz; @@ -248,7 +258,31 @@ static long rebaseInt96(long julianMicros, final boolean failIfRebase) { /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, WritableColumnVector column) throws IOException { + void readBatch(int total, WritableColumnVector _column) throws IOException { + WritableColumnVector column; + if (rowIndexesIterator.isPresent()) { + if (tempVector == null) { + switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { + case INT64: + tempVector = new OnHeapColumnVector(4096, DataTypes.LongType); + break; + case BINARY: + tempVector = new OnHeapColumnVector(4096, DataTypes.BinaryType); + break; + } + } + column = tempVector; + column.reset(); + + rowIndexes = new long[total]; + for (int i = 0; i < total; i++) { + rowIndexes[i] = rowIndexesIterator.get().next(); + } + } else { + column = _column; + } + + int rowId = 0; WritableColumnVector dictionaryIds = null; if (dictionary != null) { @@ -257,6 +291,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // page. dictionaryIds = column.reserveDictionaryIds(total); } + while (total > 0) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); @@ -338,9 +373,52 @@ void readBatch(int total, WritableColumnVector column) throws IOException { } } + if (rowIndexesIterator.isPresent()) { + boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total; + if (continuousRange) { + // skip to offset pos and dump all remaining values + int offset = (int) (rowIndexes[rowId] - currentRow); + if (offset < num) { + switch (typeName) { + case INT64: + _column.putLongs(rowId, num, column.getLongs(offset, num - offset), 0); + break; + case BINARY: + for (int i = 0; i < num - offset; i++) { + _column.putByteArray(rowId + i, column.getBinary(i + offset)); + } + break; + } + currentRow += num; + rowId += (num - offset); + total -= (num - offset); + } else { + currentRow += num; + } + } else { + // need to check every row + for (int i = 0; i < num; ) { + while (currentRow < rowIndexes[rowId]) { + i++; + currentRow++; + } + switch (typeName) { + case INT64: + _column.putLong(rowId, column.getLong(i)); + break; + case BINARY: + _column.putByteArray(rowId, column.getBinary(i)); + } + rowId++; + total--; + } + } + } else { + rowId += num; + total -= num; + } + valuesRead += num; - rowId += num; - total -= num; } } @@ -853,6 +931,7 @@ private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) thr } private void readPageV1(DataPageV1 page) throws IOException { + this.currentRow = page.getFirstRowIndex().orElse(0L); this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); ValuesReader dlReader; @@ -878,6 +957,7 @@ private void readPageV1(DataPageV1 page) throws IOException { } private void readPageV2(DataPageV2 page) throws IOException { + this.currentRow = page.getFirstRowIndex().orElse(0L); this.pageValueCount = page.getValueCount(); this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), page.getRepetitionLevels(), descriptor); @@ -894,4 +974,5 @@ private void readPageV2(DataPageV2 page) throws IOException { throw new IOException("could not read page " + page + " in col " + descriptor, e); } } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 1b159534c8a4f..7f234a32ed48a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -336,7 +336,8 @@ private void checkEndOfRowGroup() throws IOException { pages.getPageReader(columns.get(i)), convertTz, datetimeRebaseMode, - int96RebaseMode); + int96RebaseMode, + pages.getRowIndexes()); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala new file mode 100644 index 0000000000000..96bc54e23377b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.util.Random + +import org.apache.parquet.hadoop.ParquetInputFormat + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.{DataFrame, SparkSession} + +/** + * Benchmark to measure read performance with Parquet column index. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/ParquetFilterPushdownBenchmark-results.txt". + * }}} + */ +object ParquetColumnIndexBenchmark extends SqlBasedBenchmark { + + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + SparkSession.builder().config(conf).getOrCreate() + } + + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + private def prepareTable( + dir: File, numRows: Int): Unit = { + import spark.implicits._ + + val df = spark.range(numRows).map(i => (i, i + ":f" + "o" * Random.nextInt(200))).toDF() + + saveAsTable(df, dir) + } + + private def saveAsTable(df: DataFrame, dir: File, useDictionary: Boolean = false): Unit = { + val parquetPath = dir.getCanonicalPath + "/parquet" + df.write.mode("overwrite").parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5, output = output) + + Seq(false, true).foreach { columnIndexEnabled => + val name = s"Parquet Vectorized ${if (columnIndexEnabled) s"(columnIndex)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(ParquetInputFormat.COLUMN_INDEX_FILTERING_ENABLED -> s"$columnIndexEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").noop() + } + } + } + + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Pushdown for single value filter") { + withTempPath { dir => + withTempTable("parquetTable") { + prepareTable(dir, numRows) + filterPushDownBenchmark(numRows, "simple filters", s" _1 = $numRows - 100 ") + } + } + } + + runBenchmark("Pushdown for range filter") { + withTempPath { dir => + withTempTable("parquetTable") { + prepareTable(dir, numRows) + filterPushDownBenchmark(numRows, + "range filters", s" _1 > ($numRows - 1000000) and _1 < ($numRows - 1000)") + } + } + } + + } +} From 6535575f3d2c80614d82529805c08926093a3e63 Mon Sep 17 00:00:00 2001 From: Li Xian Date: Thu, 1 Apr 2021 00:24:35 +0800 Subject: [PATCH 2/7] fix indexes --- .../parquet/VectorizedColumnReader.java | 14 +++---- .../org/apache/spark/sql/SimpleTest.scala | 40 +++++++++++++++++++ .../ParquetColumnIndexBenchmark.scala | 11 +++++ 3 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index b84b8f9b13883..5f9ddd617b844 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -381,7 +381,7 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { if (offset < num) { switch (typeName) { case INT64: - _column.putLongs(rowId, num, column.getLongs(offset, num - offset), 0); + _column.putLongs(rowId, num - offset, column.getLongs(offset, num - offset), 0); break; case BINARY: for (int i = 0; i < num - offset; i++) { @@ -389,18 +389,17 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { } break; } - currentRow += num; rowId += (num - offset); total -= (num - offset); - } else { - currentRow += num; } } else { // need to check every row - for (int i = 0; i < num; ) { - while (currentRow < rowIndexes[rowId]) { + for (int i = 0; i < num && total > 0; ) { + while (currentRow + i < rowIndexes[rowId] && i < num) { i++; - currentRow++; + } + if (i >= num) { + break; } switch (typeName) { case INT64: @@ -418,6 +417,7 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { total -= num; } + currentRow += num; valuesRead += num; } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala new file mode 100644 index 0000000000000..763dc0d55948b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + + +object SimpleTest { + + def main(args: Array[String]): Unit = { + + val spark = SparkSession.builder + .appName("Simple Application") + .master("local[*]") + .getOrCreate() + spark.sparkContext.setCallSite("short") +// spark.sparkContext.setCallSite(new CallSite("shot","long")) + + val df = spark.read +// .option("parquet.filter.columnindex.enabled", "false") + .parquet("/Users/lxian/Documents/parquet-playground/part-00000-66712089-3639-4c41-84fb-36790dec7c79-c000.snappy.parquet") + +// df.filter(" _1 = 100000").queryExecution.debug.codegen() + df.filter(" _1 = 100000").show(false) +// df.filter(" _1 = 100000").show(false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala index 96bc54e23377b..139b70e36c436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala @@ -114,5 +114,16 @@ object ParquetColumnIndexBenchmark extends SqlBasedBenchmark { } } + runBenchmark("Pushdown for multi range filter") { + withTempPath { dir => + withTempTable("parquetTable") { + prepareTable(dir, numRows) + filterPushDownBenchmark(numRows, + "multi range filters", + s" (_1 > ($numRows - 3000000) and _1 < ($numRows - 2000000)) or ( _1 > ($numRows - 1000000) and _1 < ($numRows - 1000))") + } + } + } + } } From 2978f78122bb04af1d488a8529958f7576a3f0c5 Mon Sep 17 00:00:00 2001 From: Li Xian Date: Thu, 1 Apr 2021 01:00:56 +0800 Subject: [PATCH 3/7] more datatypes --- .../parquet/VectorizedColumnReader.java | 89 +++++++++++++------ .../VectorizedParquetRecordReader.java | 4 +- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 5f9ddd617b844..87e016e8f9bfa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -113,10 +113,10 @@ public class VectorizedColumnReader { private final String int96RebaseMode; // TODO handle and init these filed properly - private Optional rowIndexesIterator; - private long[] rowIndexes; - private WritableColumnVector tempVector; - private Long currentRow; + private PrimitiveIterator.OfLong rowIndexesIterator; + private long[] rowIndexes; // row indexes of current row group + private long currentRow = 0; // current row to read + private WritableColumnVector tempColumnVector; private boolean isDecimalTypeMatched(DataType dt) { DecimalType d = (DecimalType) dt; @@ -148,7 +148,7 @@ public VectorizedColumnReader( ZoneId convertTz, String datetimeRebaseMode, String int96RebaseMode, - Optional rowIndexesIterator + PrimitiveIterator.OfLong rowIndexesIterator ) throws IOException { this.rowIndexesIterator = rowIndexesIterator; this.descriptor = descriptor; @@ -258,31 +258,28 @@ static long rebaseInt96(long julianMicros, final boolean failIfRebase) { /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, WritableColumnVector _column) throws IOException { + void readBatch(int total, int columnSize, WritableColumnVector resultColumn) throws IOException { + PrimitiveType.PrimitiveTypeName typeName = + descriptor.getPrimitiveType().getPrimitiveTypeName(); + WritableColumnVector column; - if (rowIndexesIterator.isPresent()) { - if (tempVector == null) { - switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { - case INT64: - tempVector = new OnHeapColumnVector(4096, DataTypes.LongType); - break; - case BINARY: - tempVector = new OnHeapColumnVector(4096, DataTypes.BinaryType); - break; - } + + if (rowIndexesIterator != null) { + if (tempColumnVector == null) { + tempColumnVector = new OnHeapColumnVector(columnSize, resultColumn.dataType()); } - column = tempVector; + column = tempColumnVector; column.reset(); rowIndexes = new long[total]; for (int i = 0; i < total; i++) { - rowIndexes[i] = rowIndexesIterator.get().next(); + rowIndexes[i] = rowIndexesIterator.next(); } } else { - column = _column; + // write to result column directly if no row indexes if present + column = resultColumn; } - int rowId = 0; WritableColumnVector dictionaryIds = null; if (dictionary != null) { @@ -300,8 +297,6 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - PrimitiveType.PrimitiveTypeName typeName = - descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. defColumn.readIntegers( @@ -373,21 +368,57 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { } } - if (rowIndexesIterator.isPresent()) { + if (rowIndexesIterator != null) { + // copy values from temp column to result column boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total; if (continuousRange) { // skip to offset pos and dump all remaining values int offset = (int) (rowIndexes[rowId] - currentRow); if (offset < num) { + int validValueNum = num - offset; switch (typeName) { - case INT64: - _column.putLongs(rowId, num - offset, column.getLongs(offset, num - offset), 0); + case BOOLEAN: + for (int i = 0; i < validValueNum; i++) { + resultColumn.putBoolean(rowId + i, column.getBoolean(offset + i)); + } + break; + case INT32:// TODO handle column types + resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); + break; + case INT64:// TODO handle column types + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); break; + case FLOAT: + resultColumn.putFloats(rowId, validValueNum, column.getFloats(offset, validValueNum), 0); + break; + case DOUBLE: + resultColumn.putDoubles(rowId, validValueNum, column.getDoubles(offset, validValueNum), 0); + break; + case INT96: case BINARY: - for (int i = 0; i < num - offset; i++) { - _column.putByteArray(rowId + i, column.getBinary(i + offset)); + if (column.dataType() == DataTypes.TimestampType) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else { + for (int i = 0; i < num - offset; i++) { + resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); + } + } + break; + case FIXED_LEN_BYTE_ARRAY: + if (canReadAsIntDecimal(column.dataType())) { + resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); + } else if (canReadAsLongDecimal(column.dataType())) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else if (canReadAsBinaryDecimal(column.dataType())) { + for (int i = 0; i < num - offset; i++) { + resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); + } + } else { + throw constructConvertNotSupportedException(descriptor, column); } break; + default: + throw new IOException("Unsupported type: " + typeName); } rowId += (num - offset); total -= (num - offset); @@ -403,10 +434,10 @@ void readBatch(int total, WritableColumnVector _column) throws IOException { } switch (typeName) { case INT64: - _column.putLong(rowId, column.getLong(i)); + resultColumn.putLong(rowId, column.getLong(i)); break; case BINARY: - _column.putByteArray(rowId, column.getBinary(i)); + resultColumn.putByteArray(rowId, column.getBinary(i)); } rowId++; total--; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 7f234a32ed48a..68606ea299907 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -280,7 +280,7 @@ public boolean nextBatch() throws IOException { int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnVectors[i]); + columnReaders[i].readBatch(num, capacity, columnVectors[i]); } rowsReturned += num; columnarBatch.setNumRows(num); @@ -337,7 +337,7 @@ private void checkEndOfRowGroup() throws IOException { convertTz, datetimeRebaseMode, int96RebaseMode, - pages.getRowIndexes()); + pages.getRowIndexes().orElse(null)); } totalCountLoadedSoFar += pages.getRowCount(); } From 2422abc8243cb735b1b4bd32b14257749e107fa9 Mon Sep 17 00:00:00 2001 From: Li Xian Date: Sat, 3 Apr 2021 16:28:02 +0800 Subject: [PATCH 4/7] ints --- .../parquet/VectorizedColumnReader.java | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 87e016e8f9bfa..ec7b8ecb2d62f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -382,11 +382,36 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr resultColumn.putBoolean(rowId + i, column.getBoolean(offset + i)); } break; - case INT32:// TODO handle column types + case INT32: resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); + if (column.dataType() == DataTypes.IntegerType || canReadAsIntDecimal(column.dataType())) { + resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); + } else if (column.dataType() == DataTypes.LongType) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else if (column.dataType() == DataTypes.ByteType) { + resultColumn.putBytes(rowId, num - offset, column.getBytes(offset, validValueNum), 0); + } else if (column.dataType() == DataTypes.ShortType) { + resultColumn.putShorts(rowId, num - offset, column.getShorts(offset, validValueNum), 0); + } else if (column.dataType() == DataTypes.DateType ) { + resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); + } else { + throw constructConvertNotSupportedException(descriptor, column); + } break; - case INT64:// TODO handle column types - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + case INT64: + if (column.dataType() == DataTypes.LongType || canReadAsLongDecimal(column.dataType())) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else if (originalType == OriginalType.UINT_64) { + for (int i = 0; i < num - offset; i++) { + resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); + } + } else if (originalType == OriginalType.TIMESTAMP_MICROS) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { + resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); + } else { + throw constructConvertNotSupportedException(descriptor, column); + } break; case FLOAT: resultColumn.putFloats(rowId, validValueNum, column.getFloats(offset, validValueNum), 0); From ddc233a8e127c8f4a2664259f5423a735a9a815f Mon Sep 17 00:00:00 2001 From: Li Xian Date: Fri, 9 Apr 2021 01:01:47 +0800 Subject: [PATCH 5/7] handle types with column datatype --- .../parquet/VectorizedColumnReader.java | 141 ++++++++---------- .../parquet/ParquetColumnIndexSuite.scala | 61 ++++++++ 2 files changed, 124 insertions(+), 78 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index ec7b8ecb2d62f..948622c38772d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -275,6 +275,13 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr for (int i = 0; i < total; i++) { rowIndexes[i] = rowIndexesIterator.next(); } + + // if row indexes is exactly matching the range we are going to read + // there is no need to do additional row index synchronization + boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total; + if (continuousRange && rowIndexes[0] == currentRow) { + column = resultColumn; + } } else { // write to result column directly if no row indexes if present column = resultColumn; @@ -368,7 +375,7 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr } } - if (rowIndexesIterator != null) { + if (resultColumn != column) { // copy values from temp column to result column boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total; if (continuousRange) { @@ -376,93 +383,71 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr int offset = (int) (rowIndexes[rowId] - currentRow); if (offset < num) { int validValueNum = num - offset; - switch (typeName) { - case BOOLEAN: - for (int i = 0; i < validValueNum; i++) { - resultColumn.putBoolean(rowId + i, column.getBoolean(offset + i)); - } - break; - case INT32: - resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); - if (column.dataType() == DataTypes.IntegerType || canReadAsIntDecimal(column.dataType())) { - resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); - } else if (column.dataType() == DataTypes.LongType) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else if (column.dataType() == DataTypes.ByteType) { - resultColumn.putBytes(rowId, num - offset, column.getBytes(offset, validValueNum), 0); - } else if (column.dataType() == DataTypes.ShortType) { - resultColumn.putShorts(rowId, num - offset, column.getShorts(offset, validValueNum), 0); - } else if (column.dataType() == DataTypes.DateType ) { - resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); - } else { - throw constructConvertNotSupportedException(descriptor, column); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || canReadAsLongDecimal(column.dataType())) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else if (originalType == OriginalType.UINT_64) { - for (int i = 0; i < num - offset; i++) { - resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); - } - } else if (originalType == OriginalType.TIMESTAMP_MICROS) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else { - throw constructConvertNotSupportedException(descriptor, column); - } - break; - case FLOAT: - resultColumn.putFloats(rowId, validValueNum, column.getFloats(offset, validValueNum), 0); - break; - case DOUBLE: - resultColumn.putDoubles(rowId, validValueNum, column.getDoubles(offset, validValueNum), 0); - break; - case INT96: - case BINARY: - if (column.dataType() == DataTypes.TimestampType) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else { - for (int i = 0; i < num - offset; i++) { - resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); - } - } - break; - case FIXED_LEN_BYTE_ARRAY: - if (canReadAsIntDecimal(column.dataType())) { - resultColumn.putInts(rowId, validValueNum, column.getInts(offset, validValueNum), 0); - } else if (canReadAsLongDecimal(column.dataType())) { - resultColumn.putLongs(rowId, num - offset, column.getLongs(offset, validValueNum), 0); - } else if (canReadAsBinaryDecimal(column.dataType())) { - for (int i = 0; i < num - offset; i++) { - resultColumn.putByteArray(rowId + i, column.getBinary(offset + i)); - } - } else { - throw constructConvertNotSupportedException(descriptor, column); - } - break; - default: - throw new IOException("Unsupported type: " + typeName); + if (resultColumn.dataType() == DataTypes.ByteType) { + resultColumn.putBytes(rowId, validValueNum, column.getBytes(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.ShortType) { + resultColumn.putShorts(rowId, validValueNum, column.getShorts(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.IntegerType) { + resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.LongType) { + resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.DateType) { + resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.FloatType) { + resultColumn.putFloats(rowId, validValueNum, column.getFloats(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.DoubleType) { + resultColumn.putDoubles(rowId, validValueNum, column.getDoubles(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.TimestampType) { + resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.DayTimeIntervalType) { + resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.YearMonthIntervalType) { + resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.BooleanType) { + for (int i = 0; i < validValueNum; i++) { + resultColumn.putBoolean(rowId + i, column.getBoolean(rowId + offset + i)); + } + } else { + for (int i = 0; i < validValueNum; i++) { + resultColumn.putByteArray(rowId + i, column.getBinary(rowId + offset + i)); + } } - rowId += (num - offset); - total -= (num - offset); + rowId += validValueNum; + total -= validValueNum; } } else { // need to check every row - for (int i = 0; i < num && total > 0; ) { + for (int i = 0, startingRowId = rowId; i < num && total > 0; ) { while (currentRow + i < rowIndexes[rowId] && i < num) { i++; } if (i >= num) { break; } - switch (typeName) { - case INT64: - resultColumn.putLong(rowId, column.getLong(i)); - break; - case BINARY: - resultColumn.putByteArray(rowId, column.getBinary(i)); + if (resultColumn.dataType() == DataTypes.ByteType) { + resultColumn.putByte(rowId, column.getByte(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.ShortType) { + resultColumn.putShort(rowId, column.getShort(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.IntegerType) { + resultColumn.putInt(rowId, column.getInt(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.LongType) { + resultColumn.putLong(rowId, column.getLong(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.DateType) { + resultColumn.putInt(rowId, column.getInt(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.FloatType) { + resultColumn.putFloat(rowId, column.getFloat(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.DoubleType) { + resultColumn.putDouble(rowId, column.getDouble(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.TimestampType) { + resultColumn.putLong(rowId, column.getLong(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.DayTimeIntervalType) { + resultColumn.putLong(rowId, column.getLong(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.YearMonthIntervalType) { + resultColumn.putInt(rowId, column.getInt(startingRowId + i)); + } else if (resultColumn.dataType() == DataTypes.BooleanType) { + resultColumn.putBoolean(rowId, column.getBoolean(startingRowId + i)); + } else { + resultColumn.putByteArray(rowId, column.getBinary(startingRowId + i)); } rowId++; total--; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala new file mode 100644 index 0000000000000..03dfde956d1f6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSparkSession + + +class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSparkSession { + + import testImplicits._ + + /** + * create parquet file with two columns and unaligned pages + * pages will be of the following layout + * col_1 500 500 500 500 + * |---------|---------|---------|---------| + * |-------|-----|-----|---|---|---|---|---| + * col_2 400 300 200 200 200 200 200 200 + */ + def checkUnalignedPages(action: DataFrame => DataFrame): Unit = { + withTempPath(file => { + val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt)) + ds.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + + checkAnswer(action(parquetDf), action(ds.toDF())) + }) + } + + test("read from unaligned pages - single value filters") { + checkUnalignedPages(df => df.filter("_1 = 500")) + checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 1500")) + checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500")) + checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500")) + } + + test("read from unaligned pages - range filter") { + checkUnalignedPages(df => df.filter("_1 >= 500 and _1 < 1000")) + checkUnalignedPages(df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)")) + } +} From e95e7a11d25a501d6cc2570cec1184a61e6be611 Mon Sep 17 00:00:00 2001 From: Li Xian Date: Mon, 12 Apr 2021 01:37:46 +0800 Subject: [PATCH 6/7] handle dict encoding --- .../parquet/VectorizedColumnReader.java | 33 ++++--- .../parquet/ParquetColumnIndexSuite.scala | 87 ++++++++++++++++--- 2 files changed, 97 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 948622c38772d..8dc167badd99b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -289,11 +289,15 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr int rowId = 0; WritableColumnVector dictionaryIds = null; + WritableColumnVector resultDictionaryIds = null; if (dictionary != null) { // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded // page. dictionaryIds = column.reserveDictionaryIds(total); + if (column != resultColumn) { + resultDictionaryIds = resultColumn.reserveDictionaryIds(total); + } } while (total > 0) { @@ -334,6 +338,9 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64; column.setDictionary(new ParquetDictionary(dictionary, needTransform)); + if (column != resultColumn) { // set result column as well + resultColumn.setDictionary(new ParquetDictionary(dictionary, needTransform)); + } } else { decodeDictionaryIds(rowId, num, column, dictionaryIds); } @@ -341,7 +348,11 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr if (column.hasDictionary() && rowId != 0) { // This batch already has dictionary encoded values but this new page is not. The batch // does not support a mix of dictionary and not so we will decode the dictionary. - decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); + if (column != resultColumn) { + decodeDictionaryIds(0, rowId, resultColumn, resultColumn.getDictionaryIds()); + } else { + decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); + } } column.setDictionary(null); switch (typeName) { @@ -383,7 +394,9 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr int offset = (int) (rowIndexes[rowId] - currentRow); if (offset < num) { int validValueNum = num - offset; - if (resultColumn.dataType() == DataTypes.ByteType) { + if (isCurrentPageDictionaryEncoded && column.hasDictionary()) { + resultDictionaryIds.putInts(rowId, validValueNum, dictionaryIds.getInts(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.ByteType) { resultColumn.putBytes(rowId, validValueNum, column.getBytes(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.ShortType) { resultColumn.putShorts(rowId, validValueNum, column.getShorts(rowId + offset, validValueNum), 0); @@ -391,18 +404,14 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.LongType) { resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); - } else if (resultColumn.dataType() == DataTypes.DateType) { - resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.FloatType) { resultColumn.putFloats(rowId, validValueNum, column.getFloats(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.DoubleType) { resultColumn.putDoubles(rowId, validValueNum, column.getDoubles(rowId + offset, validValueNum), 0); + } else if (resultColumn.dataType() == DataTypes.DateType) { + resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.TimestampType) { resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); - } else if (resultColumn.dataType() == DataTypes.DayTimeIntervalType) { - resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0); - } else if (resultColumn.dataType() == DataTypes.YearMonthIntervalType) { - resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0); } else if (resultColumn.dataType() == DataTypes.BooleanType) { for (int i = 0; i < validValueNum; i++) { resultColumn.putBoolean(rowId + i, column.getBoolean(rowId + offset + i)); @@ -424,7 +433,9 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr if (i >= num) { break; } - if (resultColumn.dataType() == DataTypes.ByteType) { + if (isCurrentPageDictionaryEncoded && column.hasDictionary()) { + resultDictionaryIds.putInt(rowId, dictionaryIds.getInt(startingRowId + i)); + } if (resultColumn.dataType() == DataTypes.ByteType) { resultColumn.putByte(rowId, column.getByte(startingRowId + i)); } else if (resultColumn.dataType() == DataTypes.ShortType) { resultColumn.putShort(rowId, column.getShort(startingRowId + i)); @@ -440,10 +451,6 @@ void readBatch(int total, int columnSize, WritableColumnVector resultColumn) thr resultColumn.putDouble(rowId, column.getDouble(startingRowId + i)); } else if (resultColumn.dataType() == DataTypes.TimestampType) { resultColumn.putLong(rowId, column.getLong(startingRowId + i)); - } else if (resultColumn.dataType() == DataTypes.DayTimeIntervalType) { - resultColumn.putLong(rowId, column.getLong(startingRowId + i)); - } else if (resultColumn.dataType() == DataTypes.YearMonthIntervalType) { - resultColumn.putInt(rowId, column.getInt(startingRowId + i)); } else if (resultColumn.dataType() == DataTypes.BooleanType) { resultColumn.putBoolean(rowId, column.getBoolean(startingRowId + i)); } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala index 03dfde956d1f6..6cfd52ed66be5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -33,7 +33,7 @@ class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSpar * |-------|-----|-----|---|---|---|---|---| * col_2 400 300 200 200 200 200 200 200 */ - def checkUnalignedPages(action: DataFrame => DataFrame): Unit = { + def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = { withTempPath(file => { val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt)) ds.coalesce(1) @@ -43,19 +43,86 @@ class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSpar val parquetDf = spark.read.parquet(file.getCanonicalPath) - checkAnswer(action(parquetDf), action(ds.toDF())) + actions.foreach{ action => + checkAnswer(action(parquetDf), action(ds.toDF())) + } }) } - test("read from unaligned pages - single value filters") { - checkUnalignedPages(df => df.filter("_1 = 500")) - checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 1500")) - checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500")) - checkUnalignedPages(df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500")) + test("reading from unaligned pages - test filters") { + checkUnalignedPages( + // single value filter + df => df.filter("_1 = 500"), + df => df.filter("_1 = 500 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500"), + // range filter + df => df.filter("_1 >= 500 and _1 < 1000"), + df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)") + ) } - test("read from unaligned pages - range filter") { - checkUnalignedPages(df => df.filter("_1 >= 500 and _1 < 1000")) - checkUnalignedPages(df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)")) + test("test reading unaligned pages - test all types") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id as short) as _3", + "cast(id as int) as _4", + "cast(id as float) as _5", + "cast(id as double) as _6", + "cast(id as decimal(20,0)) as _7", + "cast(cast(1618161925000 + id * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + id as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500 " + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) + } + + test("test reading unaligned pages - test all types (dict encode)") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id % 10 as byte) as _2", + "cast(id % 10 as short) as _3", + "cast(id % 10 as int) as _4", + "cast(id % 10 as float) as _5", + "cast(id % 10 as double) as _6", + "cast(id % 10 as decimal(20,0)) as _7", + "cast(id % 2 as boolean) as _8", + "cast(cast(1618161925000 + (id % 10) * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + (id % 10) as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500 " + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) } } From 5a3c085da144bb00387ca63706d13d3974ad07de Mon Sep 17 00:00:00 2001 From: Xian Li Date: Tue, 13 Apr 2021 13:56:09 +0800 Subject: [PATCH 7/7] cleanup --- .../org/apache/spark/sql/SimpleTest.scala | 40 ------------------- .../ParquetColumnIndexBenchmark.scala | 21 +++++----- .../parquet/ParquetColumnIndexSuite.scala | 14 +++---- 3 files changed, 18 insertions(+), 57 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala deleted file mode 100644 index 763dc0d55948b..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/SimpleTest.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - - -object SimpleTest { - - def main(args: Array[String]): Unit = { - - val spark = SparkSession.builder - .appName("Simple Application") - .master("local[*]") - .getOrCreate() - spark.sparkContext.setCallSite("short") -// spark.sparkContext.setCallSite(new CallSite("shot","long")) - - val df = spark.read -// .option("parquet.filter.columnindex.enabled", "false") - .parquet("/Users/lxian/Documents/parquet-playground/part-00000-66712089-3639-4c41-84fb-36790dec7c79-c000.snappy.parquet") - -// df.filter(" _1 = 100000").queryExecution.debug.codegen() - df.filter(" _1 = 100000").show(false) -// df.filter(" _1 = 100000").show(false) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala index 139b70e36c436..041ece7da4d4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala @@ -28,15 +28,15 @@ import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} /** - * Benchmark to measure read performance with Parquet column index. - * To run this benchmark: - * {{{ - * 1. without sbt: bin/spark-submit --class - * 2. build/sbt "sql/test:runMain " - * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " - * Results will be written to "benchmarks/ParquetFilterPushdownBenchmark-results.txt". - * }}} - */ + * Benchmark to measure read performance with Parquet column index. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/ParquetFilterPushdownBenchmark-results.txt". + * }}} + */ object ParquetColumnIndexBenchmark extends SqlBasedBenchmark { override def getSparkSession: SparkSession = { @@ -120,7 +120,8 @@ object ParquetColumnIndexBenchmark extends SqlBasedBenchmark { prepareTable(dir, numRows) filterPushDownBenchmark(numRows, "multi range filters", - s" (_1 > ($numRows - 3000000) and _1 < ($numRows - 2000000)) or ( _1 > ($numRows - 1000000) and _1 < ($numRows - 1000))") + s" (_1 > ($numRows - 3000000) and _1 < ($numRows - 2000000))" + + s" or ( _1 > ($numRows - 1000000) and _1 < ($numRows - 1000))") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala index 6cfd52ed66be5..44bb93ff9c358 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -26,13 +26,13 @@ class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSpar import testImplicits._ /** - * create parquet file with two columns and unaligned pages - * pages will be of the following layout - * col_1 500 500 500 500 - * |---------|---------|---------|---------| - * |-------|-----|-----|---|---|---|---|---| - * col_2 400 300 200 200 200 200 200 200 - */ + * create parquet file with two columns and unaligned pages + * pages will be of the following layout + * col_1 500 500 500 500 + * |---------|---------|---------|---------| + * |-------|-----|-----|---|---|---|---|---| + * col_2 400 300 200 200 200 200 200 200 + */ def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = { withTempPath(file => { val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt))