diff --git a/dev/run-tests.py b/dev/run-tests.py index e54e098551514..f904a397f83c4 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -343,7 +343,8 @@ def get_hive_profiles(hive_version): def build_spark_maven(extra_profiles): # Enable all of the profiles for the build: build_profiles = extra_profiles + modules.root.build_profile_flags - mvn_goals = ["clean", "package", "-DskipTests"] + mvn_goals = ["dependency:purge-local-repository", "-Dinclude=org.apache.parquet", + "clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals print("[info] Building Spark using Maven with these arguments: ", " ".join(profiles_and_goals)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 2ffccdd06e504..c6594da3ddd05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -47,6 +47,8 @@ object DateTimeUtils { final val JULIAN_DAY_OF_EPOCH = 2440588 final val TimeZoneUTC = TimeZone.getTimeZone("UTC") + // for why ".normalized", see https://stackoverflow.com/a/39507023/2965879 + final val ZoneIdUTC = ZoneId.of("UTC").normalized() val TIMEZONE_OPTION = "timeZone" diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index dac18b1abe047..545456ea9c8f6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -32,6 +32,7 @@ import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.DecimalMetadata; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; @@ -46,6 +47,10 @@ import org.apache.spark.sql.types.DecimalType; import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation; +import static org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation; +import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MICROS; +import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MILLIS; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; @@ -104,7 +109,7 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; - private final OriginalType originalType; + private final LogicalTypeAnnotation logicalTypeAnnotation; // The timezone conversion to apply to int96 timestamps. Null if no conversion. private final ZoneId convertTz; private static final ZoneId UTC = ZoneOffset.UTC; @@ -136,7 +141,7 @@ private boolean canReadAsBinaryDecimal(DataType dt) { public VectorizedColumnReader( ColumnDescriptor descriptor, - OriginalType originalType, + LogicalTypeAnnotation logicalTypeAnnotation, PageReader pageReader, ZoneId convertTz, String datetimeRebaseMode, @@ -144,7 +149,7 @@ public VectorizedColumnReader( this.descriptor = descriptor; this.pageReader = pageReader; this.convertTz = convertTz; - this.originalType = originalType; + this.logicalTypeAnnotation = logicalTypeAnnotation; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -192,13 +197,14 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName boolean isSupported = false; switch (typeName) { case INT32: - isSupported = originalType != OriginalType.DATE || "CORRECTED".equals(datetimeRebaseMode); + isSupported = (!(logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) + || "CORRECTED".equals(datetimeRebaseMode)); break; case INT64: - if (originalType == OriginalType.TIMESTAMP_MICROS) { + if (isTimestampWithUnit(logicalTypeAnnotation, MICROS)) { isSupported = "CORRECTED".equals(datetimeRebaseMode); } else { - isSupported = originalType != OriginalType.TIMESTAMP_MILLIS; + isSupported = !(logicalTypeAnnotation instanceof TimestampLogicalTypeAnnotation); } break; case FLOAT: @@ -278,6 +284,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). + // TODO: replace OriginalType with something from LogicalTypeAnnotation PrimitiveType primitiveType = descriptor.getPrimitiveType(); if (primitiveType.getOriginalType() == OriginalType.DECIMAL && primitiveType.getDecimalMetadata().getPrecision() <= Decimal.MAX_INT_DIGITS() && @@ -398,14 +405,22 @@ private void decodeDictionaryIds( case INT64: if (column.dataType() == DataTypes.LongType || canReadAsLongDecimal(column.dataType()) || - (originalType == OriginalType.TIMESTAMP_MICROS && + (isTimestampWithUnit(logicalTypeAnnotation, MICROS) && "CORRECTED".equals(datetimeRebaseMode))) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { + } else if (isTimestampWithUnit(logicalTypeAnnotation, MICROS)) { + final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + long julianMicros = dictionary.decodeToLong(dictionaryIds.getDictId(i)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); + } + } + } else if (isTimestampWithUnit(logicalTypeAnnotation, MILLIS)) { if ("CORRECTED".equals(datetimeRebaseMode)) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { @@ -423,14 +438,6 @@ private void decodeDictionaryIds( } } } - } else if (originalType == OriginalType.TIMESTAMP_MICROS) { - final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); - for (int i = rowId; i < rowId + num; ++i) { - if (!column.isNullAt(i)) { - long julianMicros = dictionary.decodeToLong(dictionaryIds.getDictId(i)); - column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); - } - } } else { throw constructConvertNotSupportedException(descriptor, column); } @@ -592,7 +599,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, DecimalType.is32BitDecimalType(column.dataType())); - } else if (originalType == OriginalType.TIMESTAMP_MICROS) { + } else if (isTimestampWithUnit(logicalTypeAnnotation, MICROS)) { if ("CORRECTED".equals(datetimeRebaseMode)) { defColumn.readLongs( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, false); @@ -601,7 +608,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro defColumn.readLongsWithRebase( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, failIfRebase); } - } else if (originalType == OriginalType.TIMESTAMP_MILLIS) { + } else if (isTimestampWithUnit(logicalTypeAnnotation, MILLIS)) { if ("CORRECTED".equals(datetimeRebaseMode)) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { @@ -626,6 +633,13 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro } } + private boolean isTimestampWithUnit( + LogicalTypeAnnotation logicalTypeAnnotation, + LogicalTypeAnnotation.TimeUnit timeUnit) { + return (logicalTypeAnnotation instanceof TimestampLogicalTypeAnnotation) && + ((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).getUnit() == timeUnit; + } + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 1b159534c8a4f..32455278c4fb7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -332,7 +332,7 @@ private void checkEndOfRowGroup() throws IOException { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader( columns.get(i), - types.get(i).getOriginalType(), + types.get(i).getLogicalTypeAnnotation(), pages.getPageReader(columns.get(i)), convertTz, datetimeRebaseMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 64a1ac8675104..3785a8e3fccd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -272,20 +272,6 @@ class ParquetFileFormat lazy val footerFileMetaData = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, - pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter(_)) - .reduceOption(FilterApi.and) - } else { - None - } // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual @@ -302,6 +288,22 @@ class ParquetFileFormat None } + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, + pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive, + convertTz.orNull) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(_)) + .reduceOption(FilterApi.and) + } else { + None + } + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 73910c3943e9a..84fcc5630b2d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Instant, LocalDate, ZoneId} import java.util.Locale import scala.collection.JavaConverters.asScalaBufferConverter @@ -28,8 +28,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.SparkFilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type} -import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema._ +import org.apache.parquet.schema.LogicalTypeAnnotation._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -47,7 +47,8 @@ class ParquetFilters( pushDownDecimal: Boolean, pushDownStartWith: Boolean, pushDownInFilterThreshold: Int, - caseSensitive: Boolean) { + caseSensitive: Boolean, + convertTz: ZoneId) { // A map which contains parquet field name and data type, if predicate push down applies. // // Each key in `nameToParquetField` represents a column; `dots` are used as separators for @@ -62,8 +63,8 @@ class ParquetFilters( fields.flatMap { case p: PrimitiveType => Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName, - fieldType = ParquetSchemaType(p.getOriginalType, - p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata))) + fieldType = ParquetSchemaType(p.getLogicalTypeAnnotation(), + p.getLogicalTypeAnnotation().getClass, p.getPrimitiveTypeName, p.getTypeLength))) // Note that when g is a `Struct`, `g.getOriginalType` is `null`. // When g is a `Map`, `g.getOriginalType` is `MAP`. // When g is a `List`, `g.getOriginalType` is `LIST`. @@ -105,23 +106,25 @@ class ParquetFilters( fieldType: ParquetSchemaType) private case class ParquetSchemaType( - originalType: OriginalType, - primitiveTypeName: PrimitiveTypeName, - length: Int, - decimalMetadata: DecimalMetadata) - - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) - private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) - private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) - private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) - private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) - private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) - private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) - private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) + logicalType: LogicalTypeAnnotation, + logicalTypeClass: Class[_ <: LogicalTypeAnnotation], + primitiveTypeName: PrimitiveTypeName, + length: Int) + + private val ParquetBooleanType = ParquetSchemaType(null, null, BOOLEAN, 0) + private val ParquetByteType = ParquetSchemaType(intType(8, true), + classOf[IntLogicalTypeAnnotation], INT32, 0) + private val ParquetShortType = ParquetSchemaType(intType(16, true), + classOf[IntLogicalTypeAnnotation], INT32, 0) + private val ParquetIntegerType = ParquetSchemaType(null, null, INT32, 0) + private val ParquetLongType = ParquetSchemaType(null, null, INT64, 0) + private val ParquetFloatType = ParquetSchemaType(null, null, FLOAT, 0) + private val ParquetDoubleType = ParquetSchemaType(null, null, DOUBLE, 0) + private val ParquetStringType = ParquetSchemaType(stringType(), + classOf[StringLogicalTypeAnnotation], BINARY, 0) + private val ParquetBinaryType = ParquetSchemaType(null, null, BINARY, 0) + private val ParquetDateType = ParquetSchemaType(dateType(), + classOf[DateLogicalTypeAnnotation], INT32, 0) private def dateToDays(date: Any): Int = date match { case d: Date => DateTimeUtils.fromJavaDate(d) @@ -152,6 +155,25 @@ class ParquetFilters( Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) } + private def timestampValue(timestampType: TimestampLogicalTypeAnnotation, v: Any): JLong = + Option(v).map((timestampType.getUnit, timestampType.isAdjustedToUTC) match { + case (TimeUnit.MICROS, true) => + t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong] + case (TimeUnit.MICROS, false) => + t => DateTimeUtils.convertTz( + DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]), + DateTimeUtils.ZoneIdUTC, convertTz).asInstanceOf[JLong] + case (TimeUnit.MILLIS, true) => + _.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong] + case (TimeUnit.MILLIS, false) => + t => DateTimeUtils.microsToMillis(DateTimeUtils.convertTz( + DateTimeUtils.millisToMicros(t.asInstanceOf[Timestamp].getTime), + DateTimeUtils.ZoneIdUTC, convertTz)).asInstanceOf[JLong] + case _ => throw new IllegalArgumentException(s"Unsupported timestamp type: " + + s"TIMESTAMP(${timestampType.getUnit}, ${timestampType.isAdjustedToUTC})") + }).orNull + private def timestampToMillis(v: Any): JLong = { val micros = timestampToMicros(v) val millis = DateTimeUtils.microsToMillis(micros) @@ -186,27 +208,26 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.eq( - longColumn(n), - Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.eq( longColumn(n), - Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => - (n: Array[String], v: Any) => FilterApi.eq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => - (n: Array[String], v: Any) => FilterApi.eq( - longColumn(n), - Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => - (n: Array[String], v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } private val makeNotEq: @@ -236,24 +257,23 @@ class ParquetFilters( (n: Array[String], v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull) - case ParquetTimestampMicrosType if pushDownTimestamp => + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.notEq( - longColumn(n), - Option(v).map(timestampToMicros).orNull) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.notEq( - longColumn(n), - Option(v).map(timestampToMillis).orNull) - - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => - (n: Array[String], v: Any) => FilterApi.notEq( - intColumn(n), - Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + longColumn(n), + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.notEq( longColumn(n), Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) @@ -280,18 +300,22 @@ class ParquetFilters( case ParquetDateType if pushDownDate => (n: Array[String], v: Any) => FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v)) + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.lt( + longColumn(n), + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } @@ -317,18 +341,21 @@ class ParquetFilters( case ParquetDateType if pushDownDate => (n: Array[String], v: Any) => FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => + (n: Array[String], v: Any) => FilterApi.ltEq( + longColumn(n), + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } @@ -354,18 +381,22 @@ class ParquetFilters( case ParquetDateType if pushDownDate => (n: Array[String], v: Any) => FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => + (n: Array[String], v: Any) => + FilterApi.gt( + longColumn(n), + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } @@ -391,18 +422,22 @@ class ParquetFilters( case ParquetDateType if pushDownDate => (n: Array[String], v: Any) => FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer]) - case ParquetTimestampMicrosType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v)) - case ParquetTimestampMillisType if pushDownTimestamp => - (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v)) - - case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + case ParquetSchemaType(logicalType, _class, INT64, _) if pushDownTimestamp && + _class == classOf[TimestampLogicalTypeAnnotation] => + (n: Array[String], v: Any) => + FilterApi.gtEq( + longColumn(n), + timestampValue(logicalType.asInstanceOf[TimestampLogicalTypeAnnotation], v)) + case ParquetSchemaType(_, _class, INT32, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, INT64, _) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + case ParquetSchemaType(_, _class, FIXED_LEN_BYTE_ARRAY, length) if pushDownDecimal && + _class == classOf[DecimalLogicalTypeAnnotation] => (n: Array[String], v: Any) => FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } @@ -465,25 +500,27 @@ class ParquetFilters( case ParquetDoubleType => value.isInstanceOf[JDouble] case ParquetStringType => value.isInstanceOf[String] case ParquetBinaryType => value.isInstanceOf[Array[Byte]] - case ParquetDateType => - value.isInstanceOf[Date] || value.isInstanceOf[LocalDate] - case ParquetTimestampMicrosType | ParquetTimestampMillisType => - value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant] - case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => - isDecimalMatched(value, decimalMeta) - case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => - isDecimalMatched(value, decimalMeta) - case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => - isDecimalMatched(value, decimalMeta) + case ParquetDateType => value.isInstanceOf[Date] + case ParquetSchemaType(_, _class, INT64, _) + if _class == classOf[TimestampLogicalTypeAnnotation] => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(decimal, _class, INT32, _) + if _class == classOf[DecimalLogicalTypeAnnotation] => + isDecimalMatched(value, decimal.asInstanceOf[DecimalLogicalTypeAnnotation]) + case ParquetSchemaType(decimal, _class, INT64, _) => + isDecimalMatched(value, decimal.asInstanceOf[DecimalLogicalTypeAnnotation]) + case ParquetSchemaType(decimal, _class, FIXED_LEN_BYTE_ARRAY, _) => + isDecimalMatched(value, decimal.asInstanceOf[DecimalLogicalTypeAnnotation]) case _ => false }) } // Decimal type must make sure that filter value's scale matched the file. // If doesn't matched, which would cause data corruption. - private def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + private def isDecimalMatched( + value: Any, dec: DecimalLogicalTypeAnnotation): Boolean = value match { case decimal: JBigDecimal => - decimal.scale == decimalMeta.getScale + decimal.scale == dec.getScale case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 4a1f9154488af..d12c58f26c01c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -27,6 +27,7 @@ import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.io.api.RecordMaterializer import org.apache.parquet.schema._ +import org.apache.parquet.schema.LogicalTypeAnnotation.{listType, ListLogicalTypeAnnotation} import org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging @@ -213,11 +214,12 @@ object ParquetReadSupport { // Unannotated repeated group should be interpreted as required list of required element, so // list element type is just the group itself. Clip it. - if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + if (parquetList.getLogicalTypeAnnotation == null + && parquetList.isRepetition(Repetition.REPEATED)) { clipParquetType(parquetList, elementType, caseSensitive) } else { assert( - parquetList.getOriginalType == OriginalType.LIST, + parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation], "Invalid Parquet schema. " + "Original type of annotated Parquet lists must be LIST: " + parquetList.toString) @@ -245,7 +247,7 @@ object ParquetReadSupport { ) { Types .buildGroup(parquetList.getRepetition) - .as(OriginalType.LIST) + .as(listType()) .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) .named(parquetList.getName) } else { @@ -253,7 +255,7 @@ object ParquetReadSupport { // repetition. Types .buildGroup(parquetList.getRepetition) - .as(OriginalType.LIST) + .as(listType()) .addField( Types .repeatedGroup() @@ -284,14 +286,14 @@ object ParquetReadSupport { val clippedRepeatedGroup = Types .repeatedGroup() - .as(repeatedGroup.getOriginalType) + .as(repeatedGroup.getLogicalTypeAnnotation) .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) .named(repeatedGroup.getName) Types .buildGroup(parquetMap.getRepetition) - .as(parquetMap.getOriginalType) + .as(parquetMap.getLogicalTypeAnnotation) .addField(clippedRepeatedGroup) .named(parquetMap.getName) } @@ -309,7 +311,7 @@ object ParquetReadSupport { val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) Types .buildGroup(parquetRecord.getRepetition) - .as(parquetRecord.getOriginalType) + .as(parquetRecord.getLogicalTypeAnnotation) .addFields(clippedParquetFields: _*) .named(parquetRecord.getName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index dca12ff6b4deb..de61ed4c065eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -26,8 +26,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.{GroupType, OriginalType, Type} -import org.apache.parquet.schema.OriginalType.LIST +import org.apache.parquet.schema.{GroupType, LogicalTypeAnnotation, Type} +import org.apache.parquet.schema.LogicalTypeAnnotation.{TimeUnit, _} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96} import org.apache.spark.internal.Logging @@ -110,11 +110,12 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * - a root [[ParquetRowConverter]] for [[org.apache.parquet.schema.MessageType]] `root`, * which contains: * - a [[ParquetPrimitiveConverter]] for required - * [[org.apache.parquet.schema.OriginalType.INT_32]] field `f1`, and + * [[org.apache.parquet.schema.LogicalTypeAnnotation.intType()]] field `f1`, and * - a nested [[ParquetRowConverter]] for optional [[GroupType]] `f2`, which contains: * - a [[ParquetPrimitiveConverter]] for required * [[org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE]] field `f21`, and - * - a [[ParquetStringConverter]] for optional [[org.apache.parquet.schema.OriginalType.UTF8]] + * - a [[ParquetStringConverter]] for optional + * [[org.apache.parquet.schema.LogicalTypeAnnotation.stringType()]] * string field `f22` * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have @@ -242,6 +243,12 @@ private[parquet] class ParquetRowConverter( } } + def isTimestampWithUnit(parquetType: Type, timeUnit: LogicalTypeAnnotation.TimeUnit): Boolean = { + val logicalType = parquetType.getLogicalTypeAnnotation + logicalType.isInstanceOf[TimestampLogicalTypeAnnotation] && + logicalType.asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == timeUnit + } + /** * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type * `catalystType`. Converted values are handled by `updater`. @@ -316,18 +323,28 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) - case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => + case TimestampType if isTimestampWithUnit(parquetType, TimeUnit.MICROS) => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.setLong(timestampRebaseFunc(value)) + val time = timestampRebaseFunc(value) + val utc = parquetType.getLogicalTypeAnnotation + .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC + val adjTime = if (utc) value else DateTimeUtils.convertTz(time, + convertTz.getOrElse(DateTimeUtils.ZoneIdUTC), DateTimeUtils.ZoneIdUTC) + updater.setLong(adjTime.asInstanceOf[Long]) } } - case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + case TimestampType if isTimestampWithUnit(parquetType, TimeUnit.MILLIS) => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(timestampRebaseFunc(micros)) + val rawTime = timestampRebaseFunc(micros) + val utc = parquetType.getLogicalTypeAnnotation + .asInstanceOf[TimestampLogicalTypeAnnotation].isAdjustedToUTC + val adjTime = if (utc) rawTime else DateTimeUtils.convertTz(rawTime, + convertTz.getOrElse(DateTimeUtils.ZoneIdUTC), DateTimeUtils.ZoneIdUTC) + updater.setLong(adjTime.asInstanceOf[Long]) } } @@ -354,7 +371,8 @@ private[parquet] class ParquetRowConverter( // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor // annotated by `LIST` or `MAP` should be interpreted as a required list of required // elements where the element type is the type of the field. - case t: ArrayType if parquetType.getOriginalType != LIST => + case t: ArrayType if + !parquetType.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation] => if (parquetType.isPrimitive) { new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index ff804e25ede4b..af97986a07459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.schema._ -import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.LogicalTypeAnnotation._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ @@ -92,10 +92,10 @@ class ParquetToSparkSchemaConverter( private def convertPrimitiveField(field: PrimitiveType): DataType = { val typeName = field.getPrimitiveTypeName - val originalType = field.getOriginalType + val logicalTypeAnnotation = field.getLogicalTypeAnnotation def typeString = - if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + if (logicalTypeAnnotation == null) s"$typeName" else s"$typeName ($logicalTypeAnnotation)" def typeNotSupported() = throw new AnalysisException(s"Parquet type not supported: $typeString") @@ -110,8 +110,9 @@ class ParquetToSparkSchemaConverter( // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored // as binaries with variable lengths. def makeDecimalType(maxPrecision: Int = -1): DecimalType = { - val precision = field.getDecimalMetadata.getPrecision - val scale = field.getDecimalMetadata.getScale + val decimalType = field.getLogicalTypeAnnotation.asInstanceOf[DecimalLogicalTypeAnnotation] + val precision = decimalType.getPrecision + val scale = decimalType.getScale ParquetSchemaConverter.checkConversionRequirement( maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, @@ -128,26 +129,27 @@ class ParquetToSparkSchemaConverter( case DOUBLE => DoubleType case INT32 => - originalType match { - case INT_8 => ByteType - case INT_16 => ShortType - case INT_32 | null => IntegerType - case DATE => DateType - case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) - case UINT_8 => typeNotSupported() - case UINT_16 => typeNotSupported() - case UINT_32 => typeNotSupported() - case TIME_MILLIS => typeNotImplemented() + logicalTypeAnnotation match { + case i: IntLogicalTypeAnnotation => + if (!i.isSigned) typeNotSupported() else i.getBitWidth match { + case 8 => ByteType + case 16 => ShortType + case 32 => IntegerType + } + case null => IntegerType + case _: DateLogicalTypeAnnotation => DateType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_INT_DIGITS) + case _: TimeLogicalTypeAnnotation => typeNotImplemented() case _ => illegalType() } case INT64 => - originalType match { - case INT_64 | null => LongType - case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) - case UINT_64 => typeNotSupported() - case TIMESTAMP_MICROS => TimestampType - case TIMESTAMP_MILLIS => TimestampType + logicalTypeAnnotation match { + case i: IntLogicalTypeAnnotation => + if (!i.isSigned) typeNotSupported() else LongType + case null => LongType + case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case _: TimestampLogicalTypeAnnotation => TimestampType case _ => illegalType() } @@ -159,19 +161,21 @@ class ParquetToSparkSchemaConverter( TimestampType case BINARY => - originalType match { - case UTF8 | ENUM | JSON => StringType + logicalTypeAnnotation match { + case _: StringLogicalTypeAnnotation | + _: EnumLogicalTypeAnnotation | _: JsonLogicalTypeAnnotation => StringType case null if assumeBinaryIsString => StringType case null => BinaryType - case BSON => BinaryType - case DECIMAL => makeDecimalType() + case _: BsonLogicalTypeAnnotation => BinaryType + case _: DecimalLogicalTypeAnnotation => makeDecimalType() case _ => illegalType() } case FIXED_LEN_BYTE_ARRAY => - originalType match { - case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) - case INTERVAL => typeNotImplemented() + logicalTypeAnnotation match { + case _: DecimalLogicalTypeAnnotation => + makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) + case _: IntervalLogicalTypeAnnotation => typeNotImplemented() case _ => illegalType() } @@ -180,7 +184,7 @@ class ParquetToSparkSchemaConverter( } private def convertGroupField(field: GroupType): DataType = { - Option(field.getOriginalType).fold(convert(field): DataType) { + Option(field.getLogicalTypeAnnotation).fold(convert(field): DataType) { // A Parquet list is represented as a 3-level structure: // // group (LIST) { @@ -194,7 +198,7 @@ class ParquetToSparkSchemaConverter( // we need to check whether the 2nd level or the 3rd level refers to list element type. // // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists - case LIST => + case _: ListLogicalTypeAnnotation => ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1, s"Invalid list type $field") @@ -214,7 +218,7 @@ class ParquetToSparkSchemaConverter( // `MAP_KEY_VALUE` is for backwards-compatibility // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 // scalastyle:on - case MAP | MAP_KEY_VALUE => + case _: MapLogicalTypeAnnotation | _: MapKeyValueTypeAnnotation => ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1 && !field.getType(0).isPrimitive, s"Invalid map type: $field") @@ -342,10 +346,10 @@ class SparkToParquetSchemaConverter( Types.primitive(BOOLEAN, repetition).named(field.name) case ByteType => - Types.primitive(INT32, repetition).as(INT_8).named(field.name) + Types.primitive(INT32, repetition).as(intType(8, true)).named(field.name) case ShortType => - Types.primitive(INT32, repetition).as(INT_16).named(field.name) + Types.primitive(INT32, repetition).as(intType(16, true)).named(field.name) case IntegerType => Types.primitive(INT32, repetition).named(field.name) @@ -360,10 +364,10 @@ class SparkToParquetSchemaConverter( Types.primitive(DOUBLE, repetition).named(field.name) case StringType => - Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + Types.primitive(BINARY, repetition).as(stringType()).named(field.name) case DateType => - Types.primitive(INT32, repetition).as(DATE).named(field.name) + Types.primitive(INT32, repetition).as(dateType()).named(field.name) // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the @@ -384,9 +388,11 @@ class SparkToParquetSchemaConverter( case SQLConf.ParquetOutputTimestampType.INT96 => Types.primitive(INT96, repetition).named(field.name) case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS => - Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name) + Types.primitive(INT64, repetition).as(timestampType(true, TimeUnit.MICROS)) + .named(field.name) case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS => - Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + Types.primitive(INT64, repetition).as(timestampType(true, TimeUnit.MILLIS)) + .named(field.name) } case BinaryType => @@ -403,9 +409,7 @@ class SparkToParquetSchemaConverter( case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) - .as(DECIMAL) - .precision(precision) - .scale(scale) + .as(decimalType(scale, precision)) .length(Decimal.minBytesForPrecision(precision)) .named(field.name) @@ -418,9 +422,7 @@ class SparkToParquetSchemaConverter( if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) - .as(DECIMAL) - .precision(precision) - .scale(scale) + .as(decimalType(scale, precision)) .named(field.name) // Uses INT64 for 1 <= precision <= 18 @@ -428,18 +430,14 @@ class SparkToParquetSchemaConverter( if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) - .as(DECIMAL) - .precision(precision) - .scale(scale) + .as(decimalType(scale, precision)) .named(field.name) // Uses FIXED_LEN_BYTE_ARRAY for all other precisions case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) - .as(DECIMAL) - .precision(precision) - .scale(scale) + .as(decimalType(scale, precision)) .length(Decimal.minBytesForPrecision(precision)) .named(field.name) @@ -464,7 +462,7 @@ class SparkToParquetSchemaConverter( // `array` as its element name as below. Therefore, we build manually // the correct group type here via the builder. (See SPARK-16777) Types - .buildGroup(repetition).as(LIST) + .buildGroup(repetition).as(listType()) .addField(Types .buildGroup(REPEATED) // "array" is the name chosen by parquet-hive (1.7.0 and prior version) @@ -482,7 +480,7 @@ class SparkToParquetSchemaConverter( // Here too, we should not use `listOfElements`. (See SPARK-16777) Types - .buildGroup(repetition).as(LIST) + .buildGroup(repetition).as(listType()) // "array" is the name chosen by parquet-avro (1.7.0 and prior version) .addField(convertField(StructField("array", elementType, nullable), REPEATED)) .named(field.name) @@ -513,7 +511,7 @@ class SparkToParquetSchemaConverter( // } // } Types - .buildGroup(repetition).as(LIST) + .buildGroup(repetition).as(listType()) .addField( Types.repeatedGroup() .addField(convertField(StructField("element", elementType, containsNull))) @@ -528,7 +526,7 @@ class SparkToParquetSchemaConverter( // } // } Types - .buildGroup(repetition).as(MAP) + .buildGroup(repetition).as(mapType()) .addField( Types .repeatedGroup() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 20d0de45ba352..23c649c17a19d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -139,20 +139,6 @@ case class ParquetPartitionReaderFactory( lazy val footerFileMetaData = ParquetFileReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, - pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may @@ -168,6 +154,22 @@ case class ParquetPartitionReaderFactory( None } + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, + pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive, + convertTz.getOrElse(DateTimeUtils.ZoneIdUTC)) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 44053830defe5..9ebb118a3d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} @@ -49,10 +50,12 @@ case class ParquetScanBuilder( val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val convertTz = DateTimeUtils.getZoneId(sqlConf.sessionLocalTimeZone) val parquetSchema = new SparkToParquetSchemaConverter(sparkSession.sessionState.conf).convert(readDataSchema()) val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp, - pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) + pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive, + convertTz) parquetFilters.convertibleFilters(this.filters).toArray } diff --git a/sql/core/src/test/resources/test-data/timestamp_dictionary.parq b/sql/core/src/test/resources/test-data/timestamp_dictionary.parq new file mode 100644 index 0000000000000..cef7a4f5d4c67 Binary files /dev/null and b/sql/core/src/test/resources/test-data/timestamp_dictionary.parq differ diff --git a/sql/core/src/test/resources/test-data/timestamp_dictionary.txt b/sql/core/src/test/resources/test-data/timestamp_dictionary.txt new file mode 100644 index 0000000000000..bbbc002d11558 --- /dev/null +++ b/sql/core/src/test/resources/test-data/timestamp_dictionary.txt @@ -0,0 +1,4 @@ +-144674181;1965-06-01 05:43:39.000;1965-06-01 12:43:39.000;1965-06-01 05:43:39.000;1965-06-01 12:43:39.000 +0;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000 +0;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000 +0;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000 diff --git a/sql/core/src/test/resources/test-data/timestamp_plain.parq b/sql/core/src/test/resources/test-data/timestamp_plain.parq new file mode 100644 index 0000000000000..e262a0cdbbdb9 Binary files /dev/null and b/sql/core/src/test/resources/test-data/timestamp_plain.parq differ diff --git a/sql/core/src/test/resources/test-data/timestamp_plain.txt b/sql/core/src/test/resources/test-data/timestamp_plain.txt new file mode 100644 index 0000000000000..e439d6745565f --- /dev/null +++ b/sql/core/src/test/resources/test-data/timestamp_plain.txt @@ -0,0 +1,2 @@ +-144674181;1965-06-01 05:43:39.000;1965-06-01 12:43:39.000;1965-06-01 05:43:39.000;1965-06-01 12:43:39.000 +0;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000;1969-12-31 16:00:00.000;1970-01-01 00:00:00.000 diff --git a/sql/core/src/test/resources/test-data/timestamp_pushdown.parq b/sql/core/src/test/resources/test-data/timestamp_pushdown.parq new file mode 100644 index 0000000000000..a7e72ad7663bc Binary files /dev/null and b/sql/core/src/test/resources/test-data/timestamp_pushdown.parq differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 329a3e4983792..5230667f527f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -77,7 +77,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared new ParquetFilters(schema, conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold, - caseSensitive.getOrElse(conf.caseSensitiveAnalysis)) + caseSensitive.getOrElse(conf.caseSensitiveAnalysis), + ZoneId.systemDefault()) override def beforeEach(): Unit = { super.beforeEach() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 2fe5953cbe12e..8027a169a2e29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.time.ZoneOffset +import java.util.TimeZone import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{Path, PathFilter} @@ -26,10 +27,11 @@ import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSparkSession { test("parquet files with different physical schemas but share the same logical schema") { @@ -97,7 +99,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS } } - test("parquet timestamp conversion") { + test("parquet Impala timestamp conversion") { // Make a table with one parquet file written by impala, and one parquet file written by spark. // We should only adjust the timestamps in the impala file, and only if the conf is set val impalaFile = "test-data/impala_timestamp.parq" @@ -202,4 +204,126 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS } } } + + test("parquet timestamp read path") { + Seq("timestamp_plain", "timestamp_dictionary").foreach({ file => + val timestampPath = Thread.currentThread() + .getContextClassLoader.getResource("test-data/" + file + ".parq").toURI.getPath + val expectedPath = Thread.currentThread() + .getContextClassLoader.getResource("test-data/" + file + ".txt").toURI.getPath + + val schema = StructType(Array( + StructField("rawValue", LongType, false), + StructField("millisUtc", TimestampType, false), + StructField("millisNonUtc", TimestampType, false), + StructField("microsUtc", TimestampType, false), + StructField("microsNonUtc", TimestampType, false))) + + withTempPath {tableDir => + val textValues = spark.read + .schema(schema) + .option("inferSchema", false) + .option("header", false) + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSS") + .option("delimiter", ";").csv(expectedPath).collect + val timestamps = textValues.map( + row => (row.getLong(0), + row.getTimestamp(1), + row.getTimestamp(2), + row.getTimestamp(3), + row.getTimestamp(4) + ) + ) + FileUtils.copyFile(new File(timestampPath), new File(tableDir, "part-00001.parq")) + + Seq(false, true).foreach { vectorized => + withSQLConf( + (SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized.toString()) + ) { + val readBack = spark.read.parquet(tableDir.getAbsolutePath).collect + + val expected = timestamps.map(_.toString).sorted + val actual = readBack.map( + row => (row.getLong(0), + row.getTimestamp(1), + row.getTimestamp(2), + row.getTimestamp(3), + row.getTimestamp(4) + ) + ).map(_.toString).sorted + assert(readBack.size === expected.size) + withClue(s"vectorized = $vectorized, file = $file") { + assert(actual === expected) + } + } + } + } + }) + } + + // predicates to test for + // s"${column._2} > to_timestamp('2017-10-29 00:45:00.0')" + // s"${column._2} >= to_timestamp('2017-10-29 00:45:00.0')" + // s"${column._2} != to_timestamp('1970-01-01 00:00:55.0')" + test("parquet timestamp predicate pushdown") { + val timestampPath = Thread.currentThread() + .getContextClassLoader.getResource("test-data/timestamp_pushdown.parq").toURI.getPath + + def verifyPredicate(dataFrame: DataFrame, column: String, + item: String, predicate: String, vectorized: Boolean): Unit = { + val filter = s"$column $predicate to_timestamp('$item')" + withSQLConf( + (SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key, "true") + ) { + val withPushdown = dataFrame.where(filter).collect + withSQLConf( + (SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key, "false") + ) { + val withoutPushdown = dataFrame.where(filter).collect + withClue(s"vectorized = $vectorized, column = ${column}, item = $item") { + assert(withPushdown === withoutPushdown) + } + } + } + } + + def withTimeZone(timeZone: String)(f: => Unit): Unit = { + val tz = TimeZone.getDefault + // Java 8+ ZoneId-API does not support changing timezone at runtime + TimeZone.setDefault(TimeZone.getTimeZone(timeZone)) + try f finally TimeZone.setDefault(tz) + } + + withTempPath { tableDir => + FileUtils.copyFile(new File(timestampPath), new File(tableDir, "part-00001.parq")) + + Seq("America/Los_Angeles", "Australia/Perth").foreach({ timeZone => + withTimeZone(timeZone) { + Seq(false, true).foreach { vectorized => + withSQLConf( + (SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized.toString()) + ) { + val losAngeles = spark.read.parquet(tableDir.getAbsolutePath).select("inLosAngeles") + .collect.map(_.getString(0)) + val utc = spark.read.parquet(tableDir.getAbsolutePath).select("inUtc") + .collect.map(_.getString(0)) + val singapore = spark.read.parquet(tableDir.getAbsolutePath).select("inPerth") + .collect.map(_.getString(0)) + Seq(losAngeles, utc, singapore).foreach(values => values.foreach(item => + Seq("millisUtc", "millisNonUtc", "microsUtc", "microsNonUtc").foreach(column => { + val dataFrame = spark.read.parquet(tableDir.getAbsolutePath).select(column) + verifyPredicate(dataFrame, column, item, "=", vectorized) + verifyPredicate(dataFrame, column, item, "!=", vectorized) + verifyPredicate(dataFrame, column, item, ">", vectorized) + verifyPredicate(dataFrame, column, item, ">=", vectorized) + verifyPredicate(dataFrame, column, item, "<", vectorized) + verifyPredicate(dataFrame, column, item, "<=", vectorized) + }) + )) + } + } + } + }) + } + } }