diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ff07940422a0b..fa1f27a15ef1d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -588,7 +588,7 @@ public MapData getMap(int ordinal) {
   /**
    * Returns the decimal for rowId.
    */
-  public final Decimal getDecimal(int rowId, int precision, int scale) {
+  public Decimal getDecimal(int rowId, int precision, int scale) {
     if (precision <= Decimal.MAX_INT_DIGITS()) {
       return Decimal.createUnsafe(getInt(rowId), precision, scale);
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
@@ -617,7 +617,7 @@ public final void putDecimal(int rowId, Decimal value, int precision) {
   /**
    * Returns the UTF8String for rowId.
    */
-  public final UTF8String getUTF8String(int rowId) {
+  public UTF8String getUTF8String(int rowId) {
     if (dictionary == null) {
       ColumnVector.Array a = getByteArray(rowId);
       return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
@@ -630,7 +630,7 @@ public final UTF8String getUTF8String(int rowId) {
   /**
    * Returns the byte array for rowId.
    */
-  public final byte[] getBinary(int rowId) {
+  public byte[] getBinary(int rowId) {
     if (dictionary == null) {
       ColumnVector.Array array = getByteArray(rowId);
       byte[] bytes = new byte[array.length];
@@ -980,6 +980,19 @@ public ColumnVector getDictionaryIds() {
     return dictionaryIds;
   }
 
+  /**
+   * Creates a column vector that allocates no storage of its own: capacity is 0 and no child
+   * columns or result holders are created. Intended for subclasses that serve values from
+   * storage they manage themselves, hence protected rather than public.
+   */
+  protected ColumnVector(DataType type) {
+    this.capacity = 0;
+    this.type = type;
+    this.childColumns = null;
+    this.resultArray = null;
+    this.resultStruct = null;
+  }
+
   /**
    * Sets up the common state and also handles creating the child columns if this is a nested
    * type.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index a6ce4c2edc232..a0e31e45cf9d0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -466,6 +466,22 @@ public void filterNullsInColumn(int ordinal) {
     nullFilteredColumns.add(ordinal);
   }
 
+  /**
+   * Wraps already-allocated column vectors in a batch without allocating any storage.
+   * No schema is recorded for such a batch ({@code schema} is left null).
+   *
+   * @param columns the vectors backing this batch, one per output column
+   * @param maxRows the capacity of the batch, also used to size the row filter
+   */
+  public ColumnarBatch(ColumnVector[] columns, int maxRows) {
+    this.schema = null;
+    this.capacity = maxRows;
+    this.columns = columns;
+    this.filteredRows = new boolean[maxRows];
+    this.nullFilteredColumns = new HashSet<>();
+    this.row = new Row(this);
+  }
+
   private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) {
     this.schema = schema;
     this.capacity = maxRows;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b2a50c646bd03..5cd6ca6c3d6a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -248,6 +248,12 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ORC_VECTORIZED_READER_ENABLED =
+    SQLConfigBuilder("spark.sql.orc.enableVectorizedReader")
+      .doc("Enables vectorized orc reader.")
+      .booleanConf
+      .createWithDefault(false)
+
   val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown")
     .doc("When true, enable filter pushdown for ORC files.")
     .booleanConf
@@ -692,6 +698,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
 
+  def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED)
+
   def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
 
   def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java
new file mode 100644
index 0000000000000..e80ae2ee8499d
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java
@@ -0,0 +1,358 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.io.orc;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector class wrapping Hive's ColumnVector.
Because Spark ColumnarBatch only accepts
+ * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with
+ * Spark ColumnarBatch. This class inherits Spark's vectorized.ColumnVector class, but all data
+ * setter methods (e.g., putInt) in Spark vectorized.ColumnVector are not implemented.
+ */
+public class OrcColumnVector extends org.apache.spark.sql.execution.vectorized.ColumnVector {
+  private final ColumnVector col;
+
+  public OrcColumnVector(ColumnVector col, DataType type) {
+    super(type);
+    this.col = col;
+  }
+
+  /* Maps rowId to a slot in the Hive vector: a repeating vector keeps its one value in slot 0. */
+  private int getRowIndex(int rowId) {
+    return this.col.isRepeating ? 0 : rowId;
+  }
+
+  @Override
+  public long valuesNativeAddress() {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public long nullsNativeAddress() {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void close() {
+  }
+
+  //
+  // APIs dealing with nulls
+  //
+
+  @Override
+  public void putNotNull(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putNull(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putNulls(int rowId, int count) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putNotNulls(int rowId, int count) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public boolean isNullAt(int rowId) {
+    return !col.noNulls && col.isNull[getRowIndex(rowId)];  // isNull[] is valid only if !noNulls
+  }
+
+  //
+  // APIs dealing with Booleans
+  //
+
+  @Override
+  public void putBoolean(int rowId, boolean value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBooleans(int rowId, int count, boolean value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    LongColumnVector col = (LongColumnVector) this.col;
+    return col.vector[getRowIndex(rowId)] > 0;
+  }
+
+  //
+  // APIs dealing with Bytes
+  //
+
+  @Override
+  public void putByte(int rowId, byte value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBytes(int rowId, int count, byte value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBytes(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    LongColumnVector col = (LongColumnVector) this.col;
+    return (byte) col.vector[getRowIndex(rowId)];
+  }
+
+  //
+  // APIs dealing with Shorts
+  //
+
+  @Override
+  public void putShort(int rowId, short value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putShorts(int rowId, int count, short value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putShorts(int rowId, int count, short[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    LongColumnVector col = (LongColumnVector) this.col;
+    return (short) col.vector[getRowIndex(rowId)];
+  }
+
+  //
+  // APIs dealing with Ints
+  //
+
+  @Override
+  public void putInt(int rowId, int value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putInts(int rowId, int count, int value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putInts(int rowId, int count, int[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    LongColumnVector col = (LongColumnVector) this.col;
+    return (int) col.vector[getRowIndex(rowId)];
+  }
+
+  /**
+   * Returns the dictionary Id for rowId.
+   */
+  @Override
+  public int getDictId(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  //
+  // APIs dealing with Longs
+  //
+
+  @Override
+  public void putLong(int rowId, long value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongs(int rowId, int count, long value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongs(int rowId, int count, long[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    LongColumnVector col = (LongColumnVector) this.col;
+    return col.vector[getRowIndex(rowId)];
+  }
+
+  //
+  // APIs dealing with floats
+  //
+
+  @Override
+  public void putFloat(int rowId, float value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, float value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, float[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    DoubleColumnVector col = (DoubleColumnVector) this.col;
+    return (float) col.vector[getRowIndex(rowId)];
+  }
+
+  //
+  // APIs dealing with doubles
+  //
+
+  @Override
+  public void putDouble(int rowId, double value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, double value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    DoubleColumnVector col = (DoubleColumnVector) this.col;
+    return col.vector[getRowIndex(rowId)];
+  }
+
+  //
+  // APIs dealing with Arrays
+  //
+
+  @Override
+  public int getArrayLength(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public int getArrayOffset(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putArray(int rowId, int offset, int length) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void loadBytes(org.apache.spark.sql.execution.vectorized.ColumnVector.Array array) {
+    throw new NotImplementedException();
+  }
+
+  /**
+   * Returns the decimal for rowId.
+   */
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    DecimalColumnVector col = (DecimalColumnVector) this.col;
+    int index = getRowIndex(rowId);
+    return Decimal.apply(col.vector[index].getHiveDecimal().bigDecimalValue(), precision, scale);
+  }
+
+  /**
+   * Returns the UTF8String for rowId.
+   */
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    BytesColumnVector col = (BytesColumnVector) this.col;
+    int index = getRowIndex(rowId);
+    return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
+  }
+
+  /**
+   * Returns the byte array for rowId.
+   */
+  @Override
+  public byte[] getBinary(int rowId) {
+    BytesColumnVector col = (BytesColumnVector) this.col;
+    int index = getRowIndex(rowId);
+    byte[] binary = new byte[col.length[index]];
+    System.arraycopy(col.vector[index], col.start[index], binary, 0, binary.length);
+    return binary;
+  }
+
+  //
+  // APIs dealing with Byte Arrays
+  //
+  @Override
+  public int putByteArray(int rowId, byte[] value, int offset, int length) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  protected void reserveInternal(int newCapacity) {
+    throw new NotImplementedException();
+  }
+}
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java
new file mode 100644
index 0000000000000..f220d4d6e6fac
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.io.orc;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapred.FileSplit;
+import org.apache.hadoop.mapred.RecordReader;
+
+/**
+ * A mapred.RecordReader that returns VectorizedRowBatch.
+ */
+public class SparkVectorizedOrcRecordReader
+    implements RecordReader<NullWritable, VectorizedRowBatch> {
+  private final org.apache.hadoop.hive.ql.io.orc.RecordReader reader;
+  private final long offset;
+  private final long length;
+  private float progress = 0.0f;
+  private final ObjectInspector objectInspector;
+  private final List<Integer> columnIDs;
+
+  public SparkVectorizedOrcRecordReader(
+      Reader file,
+      Configuration conf,
+      FileSplit fileSplit,
+      List<Integer> columnIDs) throws IOException {
+    this.offset = fileSplit.getStart();
+    this.length = fileSplit.getLength();
+    this.objectInspector = file.getObjectInspector();
+    this.columnIDs = columnIDs;
+    this.reader = OrcInputFormat.createReaderFromFile(file, conf, this.offset,
+      this.length);
+    this.progress = reader.getProgress();
+  }
+
+  /**
+   * Create a ColumnVector based on given ObjectInspector's type info.
+   *
+   * @param inspector ObjectInspector
+   */
+  private ColumnVector createColumnVector(ObjectInspector inspector) {
+    switch(inspector.getCategory()) {
+      case PRIMITIVE:
+      {
+        PrimitiveTypeInfo primitiveTypeInfo =
+          (PrimitiveTypeInfo) ((PrimitiveObjectInspector)inspector).getTypeInfo();
+        switch(primitiveTypeInfo.getPrimitiveCategory()) {
+          case BOOLEAN:
+          case BYTE:
+          case SHORT:
+          case INT:
+          case LONG:
+          case DATE:
+          case INTERVAL_YEAR_MONTH:
+            return new LongColumnVector(VectorizedRowBatch.DEFAULT_SIZE);
+          case FLOAT:
+          case DOUBLE:
+            return new DoubleColumnVector(VectorizedRowBatch.DEFAULT_SIZE);
+          case BINARY:
+          case STRING:
+          case CHAR:
+          case VARCHAR:
+            BytesColumnVector column = new BytesColumnVector(VectorizedRowBatch.DEFAULT_SIZE);
+            column.initBuffer();
+            return column;
+          case DECIMAL:
+            DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo;
+            return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE,
+              decimalTypeInfo.precision(), decimalTypeInfo.scale());
+          default:
+            throw new RuntimeException("Vectorization is not supported for datatype:" +
+              primitiveTypeInfo.getPrimitiveCategory() + ". " +
+              "Please disable spark.sql.orc.enableVectorizedReader.");
+        }
+      }
+      default:
+        throw new RuntimeException("Vectorization is not supported for datatype:" +
+          inspector.getCategory() + ". " +
+          "Please disable the config spark.sql.orc.enableVectorizedReader.");
+    }
+  }
+
+  /**
+   * Create VectorizedRowBatch from ObjectInspector. Only the columns named in columnIDs get a
+   * vector; every other slot of the batch is left null.
+   * @param oi StructObjectInspector
+   * @return VectorizedRowBatch
+   */
+  private VectorizedRowBatch constructVectorizedRowBatch(StructObjectInspector oi) {
+    List<? extends StructField> fields = oi.getAllStructFieldRefs();
+    VectorizedRowBatch result = new VectorizedRowBatch(fields.size());
+    for (int i = 0; i < columnIDs.size(); i++) {
+      int fieldIndex = columnIDs.get(i);
+      ObjectInspector fieldObjectInspector = fields.get(fieldIndex).getFieldObjectInspector();
+      result.cols[fieldIndex] = createColumnVector(fieldObjectInspector);
+    }
+    return result;
+  }
+
+  @Override
+  public boolean next(NullWritable key, VectorizedRowBatch value) throws IOException {
+    if (reader.hasNext()) {
+      try {
+        reader.nextBatch(value);
+        progress = reader.getProgress();
+        return (value != null && !value.endOfFile && value.size > 0);
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public NullWritable createKey() {
+    return NullWritable.get();
+  }
+
+  @Override
+  public VectorizedRowBatch createValue() {
+    return constructVectorizedRowBatch((StructObjectInspector) this.objectInspector);
+  }
+
+  @Override
+  public long getPos() throws IOException {
+    return offset + (long) (progress * length);
+  }
+
+  @Override
+  public void close() throws IOException {
+    reader.close();
+  }
+
+  @Override
+  public float getProgress() throws IOException {
+    return progress;
+  }
+}
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java
new file mode 100644
index 0000000000000..df4295d76bd6b
---
/dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.io.orc;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A RecordReader that returns ColumnarBatch for Spark SQL execution.
+ * This reader uses an internal reader that returns Hive's VectorizedRowBatch.
+ */
+public class VectorizedSparkOrcNewRecordReader
+    extends org.apache.hadoop.mapreduce.RecordReader<NullWritable, Object> {
+  private final org.apache.hadoop.mapred.RecordReader<NullWritable, VectorizedRowBatch> reader;
+  private final int numColumns;
+  private VectorizedRowBatch hiveBatch;
+  private float progress = 0.0f;
+  private List<Integer> columnIDs;
+
+  private ColumnVector[] orcColumns;
+  private ColumnarBatch columnarBatch;
+
+  /**
+   * If true, this class returns batches instead of rows.
+   */
+  private boolean returnColumnarBatch;
+
+  private long numRowsOfBatch = 0;
+  private int indexOfRow = 0;
+
+  public VectorizedSparkOrcNewRecordReader(
+      Reader file,
+      Configuration conf,
+      FileSplit fileSplit,
+      List<Integer> columnIDs,
+      StructType requiredSchema,
+      StructType partitionColumns,
+      InternalRow partitionValues) throws IOException {
+    List<OrcProto.Type> types = file.getTypes();
+    numColumns = (types.size() == 0) ? 0 : types.get(0).getSubtypesCount();
+    this.reader = new SparkVectorizedOrcRecordReader(file, conf,
+      new org.apache.hadoop.mapred.FileSplit(fileSplit), columnIDs);
+
+    this.hiveBatch = this.reader.createValue();
+
+    this.columnIDs = new ArrayList<>(columnIDs);
+    this.orcColumns = new ColumnVector[columnIDs.size() + partitionValues.numFields()];
+
+    // Allocate Spark ColumnVectors for data columns.
+    for (int i = 0; i < columnIDs.size(); i++) {
+      org.apache.hadoop.hive.ql.exec.vector.ColumnVector col =
+        this.hiveBatch.cols[columnIDs.get(i)];
+      this.orcColumns[i] = new OrcColumnVector(col, requiredSchema.fields()[i].dataType());
+    }
+
+    // Allocate Spark ColumnVectors for partition columns.
+    if (partitionValues.numFields() > 0) {
+      int i = 0;
+      int base = columnIDs.size();
+      for (StructField f : partitionColumns.fields()) {
+        // Use onheap for partition column vectors.
+        ColumnVector col = ColumnVector.allocate(
+          VectorizedRowBatch.DEFAULT_SIZE,
+          f.dataType(),
+          MemoryMode.ON_HEAP);
+        ColumnVectorUtils.populate(col, partitionValues, i);
+        col.setIsConstant();
+        this.orcColumns[base + i] = col;
+        i++;
+      }
+    }
+
+    // Allocate Spark ColumnBatch
+    this.columnarBatch = new ColumnarBatch(this.orcColumns, VectorizedRowBatch.DEFAULT_SIZE);
+
+    this.progress = reader.getProgress();
+  }
+
+  @Override
+  public void close() throws IOException {
+    reader.close();
+  }
+
+  /*
+   * Can be called before any rows are returned to enable returning columnar batches directly.
+   */
+  public void enableReturningBatches() {
+    returnColumnarBatch = true;
+  }
+
+  @Override
+  public NullWritable getCurrentKey() throws IOException, InterruptedException {
+    return NullWritable.get();
+  }
+
+  @Override
+  public Object getCurrentValue() throws IOException, InterruptedException {
+    if (returnColumnarBatch) return this.columnarBatch;
+    return columnarBatch.getRow(indexOfRow - 1);
+  }
+
+  @Override
+  public float getProgress() throws IOException, InterruptedException {
+    return progress;
+  }
+
+  @Override
+  public void initialize(InputSplit split, TaskAttemptContext context)
+      throws IOException, InterruptedException {
+  }
+
+  @Override
+  public boolean nextKeyValue() throws IOException, InterruptedException {
+    if (returnColumnarBatch) return nextBatch();
+
+    // Row mode: refill once the current batch is exhausted, then advance the cursor so that
+    // getCurrentValue(), which reads row (indexOfRow - 1), never sees index -1.
+    while (indexOfRow >= numRowsOfBatch) {
+      if (!nextBatch()) return false;
+    }
+    indexOfRow++;
+    return true;
+  }
+
+  /**
+   * Advances to the next batch of rows. Returns false if there are no more.
+   */
+  public boolean nextBatch() throws IOException, InterruptedException {
+    if (reader.next(NullWritable.get(), hiveBatch)) {
+      if (hiveBatch.endOfFile) {
+        progress = 1.0f;
+        numRowsOfBatch = 0;
+        columnarBatch.setNumRows((int) numRowsOfBatch);
+        indexOfRow = 0;
+        return false;
+      } else {
+        assert hiveBatch.numCols == numColumns : "Incorrect number of columns in the current batch";
+        numRowsOfBatch = hiveBatch.count();
+        columnarBatch.setNumRows((int) numRowsOfBatch);
+        indexOfRow = 0;
+        progress = reader.getProgress();
+        return true;
+      }
+    } else {
+      return false;
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 42c92ed5cae26..c572c7f8da9f1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc
 import
java.net.URI
 import java.util.Properties
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
@@ -35,10 +37,11 @@ import org.apache.spark.TaskContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
 import org.apache.spark.sql.sources.{Filter, _}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{AtomicType, StructType, TimestampType}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -107,6 +110,20 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
     true
   }
 
+  override def buildReaderWithPartitionValues(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String],
+      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+    // For Orc data source, `buildReader` already handles partition values appending. Here we
+    // simply delegate to `buildReader`.
+    buildReader(
+      sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+  }
+
   override def buildReader(
       sparkSession: SparkSession,
       dataSchema: StructType,
@@ -126,6 +143,15 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
+    val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+    val enableVectorizedReader: Boolean =
+      sparkSession.sessionState.conf.orcVectorizedReaderEnabled &&
+      resultSchema.forall(f => f.dataType.isInstanceOf[AtomicType] &&
+        !f.dataType.isInstanceOf[TimestampType])
+
+    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+    val returningBatch = supportBatch(sparkSession, resultSchema)
+
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
 
@@ -139,34 +165,76 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
         val physicalSchema = maybePhysicalSchema.get
         OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema)
 
-        val orcRecordReader = {
-          val job = Job.getInstance(conf)
-          FileInputFormat.setInputPaths(job, file.filePath)
-
-          val fileSplit = new FileSplit(
-            new Path(new URI(file.filePath)), file.start, file.length, Array.empty
-          )
-          // Custom OrcRecordReader is used to get
-          // ObjectInspector during recordReader creation itself and can
-          // avoid NameNode call in unwrapOrcStructs per file.
-          // Specifically would be helpful for partitioned datasets.
-          val orcReader = OrcFile.createReader(
-            new Path(new URI(file.filePath)), OrcFile.readerOptions(conf))
-          new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
-        }
-
-        val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
-        Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close()))
+        val job = Job.getInstance(conf)
+        FileInputFormat.setInputPaths(job, file.filePath)
+
+        val fileSplit = new FileSplit(
+          new Path(new URI(file.filePath)), file.start, file.length, Array.empty
+        )
+        // Custom OrcRecordReader is used to get
+        // ObjectInspector during recordReader creation itself and can
+        // avoid NameNode call in unwrapOrcStructs per file.
+        // Specifically would be helpful for partitioned datasets.
+        val orcReader = OrcFile.createReader(
+          new Path(new URI(file.filePath)), OrcFile.readerOptions(conf))
+
+        if (enableVectorizedReader) {
+          // Keep the ids in `requiredSchema` order: the reader pairs columnIDs(i) with
+          // requiredSchema(i), so sorting would mis-bind columns when the file order differs.
+          val columnIDs =
+            requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).asJava
+          val orcRecordReader = new VectorizedSparkOrcNewRecordReader(
+            orcReader,
+            conf,
+            fileSplit,
+            columnIDs,
+            requiredSchema,
+            partitionSchema,
+            file.partitionValues)
+
+          if (returningBatch) {
+            orcRecordReader.enableReturningBatches()
+          }
+          val recordsIterator = new RecordReaderIterator(orcRecordReader)
+          Option(TaskContext.get())
+            .foreach(_.addTaskCompletionListener(_ => recordsIterator.close()))
+          // VectorizedSparkOrcNewRecordReader appends the columns internally to avoid another copy.
+          recordsIterator.asInstanceOf[Iterator[InternalRow]]
+        } else {
+          val orcRecordReader =
+            new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
+          val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
+          Option(TaskContext.get())
+            .foreach(_.addTaskCompletionListener(_ => recordsIterator.close()))
+
+          // Unwraps `OrcStruct`s to `UnsafeRow`s
+          val iter = OrcRelation.unwrapOrcStructs(
+            conf,
+            requiredSchema,
+            Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
+            recordsIterator)
+
+          if (partitionSchema.length == 0) {
+            // There is no partition columns
+            iter
+          } else {
+            val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+            val joinedRow = new JoinedRow()
+            val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
 
-        // Unwraps `OrcStruct`s to `UnsafeRow`s
-        OrcRelation.unwrapOrcStructs(
-          conf,
-          requiredSchema,
-          Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
-          recordsIterator)
+            iter.map(d => appendPartitionColumns(joinedRow(d, file.partitionValues)))
+          }
+        }
       }
     }
   }
+
+  /**
+   * Returns whether the reader will return the rows as batch or not.
+   */
+  override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
+    OrcRelation.supportBatch(sparkSession, schema)
+  }
 }
 
 private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
@@ -309,4 +377,15 @@ private[orc] object OrcRelation extends HiveInspectors {
     val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
     HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
   }
+
+  /**
+   * Returns whether the reader will return the rows as batch or not.
+ */ + def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(f => f.dataType.isInstanceOf[AtomicType] && + !f.dataType.isInstanceOf[TimestampType]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b8761e9de2886..63f1022063d96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets import java.sql.Timestamp +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.scalatest.BeforeAndAfterAll @@ -54,7 +56,43 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { +class OrcQuerySuite extends OrcQueryBase { + override protected val value = "false" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } +} + +class OrcQueryVectorizedSuite extends OrcQueryBase { + override protected val value = "true" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + 
currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } +} + +abstract class OrcQueryBase extends QueryTest with BeforeAndAfterAll with OrcTest { + protected val key = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + protected val value: String test("Read/write All Types") { val data = (0 to 255).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala new file mode 100644 index 0000000000000..c61956c809240 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc.vectorized + +import scala.util.Random + +import org.apache.commons.lang.NotImplementedException +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector +import org.apache.hadoop.hive.ql.io.orc.OrcColumnVector +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class OrcColumnVectorSuite extends SparkFunSuite { + // This helper method access the internal vector of Hive's ColumnVector classes. + private def fillColumnVector[T](col: ColumnVector, values: Seq[T]): Unit = { + col match { + case lv: LongColumnVector => + assert(lv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + lv.vector(idx) = v.asInstanceOf[Long] + } + case bv: BytesColumnVector => + assert(bv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + val array = v.asInstanceOf[Seq[Byte]].toArray + bv.vector(idx) = array + bv.start(idx) = 0 + bv.length(idx) = array.length + } + case dv: DoubleColumnVector => + assert(dv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + dv.vector(idx) = v.asInstanceOf[Double] + } + case dv: DecimalColumnVector => + assert(dv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + val writable = new HiveDecimalWritable(v.asInstanceOf[HiveDecimal]) + dv.vector(idx) = writable + } + case _ => + assert(false, s"${col.getClass.getName} is not supported") + } + } + + private def dataGenerator[T](num: Int)(randomize: (Int) => T): 
Seq[T] = { + (0 until num).map { i => + randomize(i) + } + } + + private def getAllRowsFromColumn[T] + (rowNum: Int, col: OrcColumnVector)(accessor: (OrcColumnVector, Int) => T): Seq[T] = { + (0 until rowNum).map { rowId => + accessor(col, rowId) + } + } + + private def testLongColumnVector[T](num: Int, dt: DataType) + (genExpected: (Seq[Long] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val data = dataGenerator(num) { _ => + random.nextLong() + } + + val lv = new LongColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv, dt) + val actual = genActual(orcCol, num) + assert(actual === expected) + } + + private def testDoubleColumnVector[T](num: Int, dt: DataType) + (genExpected: (Seq[Double] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val data = dataGenerator(num) { _ => + random.nextDouble() + } + + val lv = new DoubleColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv, dt) + val actual = genActual(orcCol, num) + assert(actual === expected) + } + + private def testBytesColumnVector[T](num: Int, dt: DataType) + (genExpected: (Seq[Seq[Byte]] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val schema = new StructType().add("binary", BinaryType, false) + val data = dataGenerator(num) { _ => + RandomDataGenerator.randomRow(random, schema).getAs[Array[Byte]](0).toSeq + } + + val lv = new BytesColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv, dt) + val actual = genActual(orcCol, 
num) + actual.zip(expected).foreach { case (a, e) => + assert(a === e) + } + } + + private def testDecimalColumnVector(num: Int) + (genExpected: (Seq[HiveDecimal] => Seq[java.math.BigDecimal])) + (genActual: (OrcColumnVector, Int, Int, Int) => Seq[java.math.BigDecimal]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val decimalTypes = Seq(DecimalType.ShortDecimal, DecimalType.IntDecimal, + DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal) + + decimalTypes.foreach { decimalType => + val schema = new StructType().add("decimal", decimalType, false) + val data = dataGenerator(num) { _ => + val javaDecimal = RandomDataGenerator.randomRow(random, schema).getDecimal(0) + HiveDecimal.create(javaDecimal) + } + + val lv = new DecimalColumnVector(num, decimalType.precision, decimalType.scale) + fillColumnVector(lv, data) + assert(data === lv.vector.map(_.getHiveDecimal(decimalType.precision, decimalType.scale))) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv, decimalType) + val actual = genActual(orcCol, num, decimalType.precision, decimalType.scale) + actual.zip(expected).foreach { case (a, e) => + assert(a.compareTo(e) == 0) + } + } + } + + test("Hive LongColumnVector: Boolean") { + val genExpected = (data: Seq[Long]) => data.map(_ > 0) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getBoolean(rowId) + } + } + testLongColumnVector(100, BooleanType)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Int") { + val genExpected = (data: Seq[Long]) => data.map(_.toInt) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getInt(rowId) + } + } + testLongColumnVector(100, IntegerType)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Byte") { + val genExpected = (data: Seq[Long]) => data.map(_.toByte) + val genActual = (orcCol: 
OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getByte(rowId) + } + } + testLongColumnVector(100, ByteType)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Short") { + val genExpected = (data: Seq[Long]) => data.map(_.toShort) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getShort(rowId) + } + } + testLongColumnVector(100, ShortType)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Long") { + val genExpected = (data: Seq[Long]) => data + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getLong(rowId) + } + } + testLongColumnVector(100, LongType)(genExpected)(genActual) + } + + test("Hive DoubleColumnVector: Float") { + val genExpected = (data: Seq[Double]) => data.map(_.toFloat) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getFloat(rowId) + } + } + testDoubleColumnVector(100, FloatType)(genExpected)(genActual) + } + + test("Hive DoubleColumnVector: Double") { + val genExpected = (data: Seq[Double]) => data + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getDouble(rowId) + } + } + testDoubleColumnVector(100, DoubleType)(genExpected)(genActual) + } + + test("Hive BytesColumnVector: Binary") { + val genExpected = (data: Seq[Seq[Byte]]) => data + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getBinary(rowId).toSeq + } + } + testBytesColumnVector(100, BinaryType)(genExpected)(genActual) + } + + test("Hive BytesColumnVector: String") { + val genExpected = (data: Seq[Seq[Byte]]) => { + data.map(bytes => UTF8String.fromBytes(bytes.toArray, 0, bytes.length)) + } + + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, 
orcCol) { (col, rowId) => + col.getUTF8String(rowId) + } + } + testBytesColumnVector(100, StringType)(genExpected)(genActual) + } + + test("Hive DecimalColumnVector") { + val genExpected = (data: Seq[HiveDecimal]) => data.map(_.bigDecimalValue()) + val genActual = (orcCol: OrcColumnVector, num: Int, precision: Int, scale: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getDecimal(rowId, precision, scale).toJavaBigDecimal + } + } + testDecimalColumnVector(100)(genExpected)(genActual) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala new file mode 100644 index 0000000000000..73ce68a8aebab --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc.vectorized + +import java.io.File +import java.net.URI +import java.nio.charset.StandardCharsets +import java.sql.Date + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch +import org.apache.hadoop.hive.ql.io.orc.{Reader, SparkVectorizedOrcRecordReader, VectorizedSparkOrcNewRecordReader} +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.hive.orc._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAfterAll with OrcTest { + val key = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + val value = "true" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } + + private def prepareParametersForReader( + filepath: String, + requiredSchema: StructType): (Configuration, Reader, FileSplit, java.util.List[Integer]) = { + val conf = new Configuration() + val physicalSchema = OrcFileOperator.readSchema(Seq(filepath), Some(conf)).get + OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + val orcReader = OrcFileOperator.getFileReader(filepath, Some(conf)).get + + val file = new File(filepath) + 
val fileSplit = new FileSplit(new Path(new URI(filepath)), 0, file.length(), Array.empty) + val columnIDs = + requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava + + (conf, orcReader, fileSplit, columnIDs) + } + + private def getOrcRecordReader( + filepath: String, + requiredSchema: StructType): SparkVectorizedOrcRecordReader = { + val (conf, orcReader, fileSplit, columnIDs) = + prepareParametersForReader(filepath, requiredSchema) + new SparkVectorizedOrcRecordReader( + orcReader, + conf, + new org.apache.hadoop.mapred.FileSplit(fileSplit), + columnIDs) + } + + private def getVectorizedOrcReader( + filepath: String, + requiredSchema: StructType, + partitionSchema: StructType, + partitionValues: InternalRow): VectorizedSparkOrcNewRecordReader = { + val (conf, orcReader, fileSplit, columnIDs) = + prepareParametersForReader(filepath, requiredSchema) + val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val reader = + new VectorizedSparkOrcNewRecordReader( + orcReader, conf, fileSplit, columnIDs, requiredSchema, partitionSchema, partitionValues) + + val returningBatch: Boolean = OrcRelation.supportBatch(spark, resultSchema) + if (returningBatch) { + reader.enableReturningBatches() + } + reader + } + + // Test data reading with VectorizedSparkOrcNewRecordReader: + // VectorizedSparkOrcNewRecordReader supports batch processing with Spark's ColumnarBatch. + // We test it with/without partitions. 
+ + val partitionSchemas = Seq( + StructType(Nil), + new StructType().add("p1", IntegerType).add("p2", LongType)) + + val partitionValues = Seq( + InternalRow.empty, + InternalRow(1, 2L)) + + val partitionSettings = partitionSchemas.zip(partitionValues) + + partitionSettings.map { case (partitionSchema, partitionValue) => + val doPartition = partitionValue != InternalRow.empty + val partitionTitle = if (doPartition) "with partition" else "" + + test(s"Read types: batch processing $partitionTitle") { + val colNum = if (doPartition) 13 else 11 + val data = (0 to 255).map { i => + val dateString = "2015-08-20" + val milliseconds = Date.valueOf(dateString).getTime + i * 3600 + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + s"$i".getBytes(StandardCharsets.UTF_8), Decimal(i.toDouble).toJavaBigDecimal, + new Date(milliseconds)) + } + val expectedRows = data.map { x => + val data = Seq(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9, + Decimal(x._10), DateTimeUtils.fromJavaDate(x._11)) + val dataWithPartition = if (doPartition) { + data ++ Seq(1, 2L) + } else { + data + } + InternalRow.fromSeq(dataWithPartition) + } + + withOrcFile(data) { file => + val requiredSchema = new StructType() + .add("_1", StringType) + .add("_2", IntegerType) + .add("_3", LongType) + .add("_4", FloatType) + .add("_5", DoubleType) + .add("_6", ShortType) + .add("_7", ByteType) + .add("_8", BooleanType) + .add("_9", BinaryType) + .add("_10", DecimalType.LongDecimal) + .add("_11", DateType) + val reader = getVectorizedOrcReader(file, requiredSchema, partitionSchema, partitionValue) + assert(reader.nextKeyValue()) + + // The schema is supported by ColumnarBatch. 
+ val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch]) + + val batch = nextValue.asInstanceOf[ColumnarBatch] + + assert(batch.numCols() == colNum) + assert(batch.numRows() == 256) + assert(batch.numValidRows() == 256) + assert(batch.capacity() > 0) + assert(batch.rowIterator().hasNext == true) + + assert(batch.column(0).getUTF8String(0).toString() == "0") + assert(batch.column(0).isNullAt(0) == false) + assert(batch.column(1).getInt(0) == 0) + assert(batch.column(1).isNullAt(0) == false) + assert(batch.column(4).getDouble(0) == 0.0) + assert(batch.column(4).isNullAt(0) == false) + + val it = batch.rowIterator() + expectedRows.map { row => + assert(it.hasNext()) + assert(it.next().copy() == row) + } + } + } + + test(s"Read types: no batch processing $partitionTitle") { + val dataColNum = spark.conf.get(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key).toInt + 1 + val colNum = if (doPartition) { + dataColNum + 2 + } else { + dataColNum + } + + val data = (0 to 255).map { i => + Row.fromSeq((i to dataColNum + i - 1).toSeq) + } + + val expectedRows = data.map { x => + val data = x.toSeq + val dataWithPartition = if (doPartition) { + data ++ Seq(1, 2L) + } else { + data + } + InternalRow.fromSeq(dataWithPartition) + } + + withTempPath { file => + val fields = (1 to dataColNum).map { idx => + StructField(s"_$idx", IntegerType) + } + val requiredSchema = StructType(fields.toArray) + spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) + .write.orc(file.getCanonicalPath) + val path = file.getCanonicalPath + + val reader = getVectorizedOrcReader(path, requiredSchema, partitionSchema, partitionValue) + assert(reader.nextKeyValue()) + + // Column number exceeds SQLConf.WHOLESTAGE_MAX_NUM_FIELDS, + // so batch processing is not supported. 
+ val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch.Row]) + + val batchRow = nextValue.asInstanceOf[ColumnarBatch.Row] + + assert(batchRow.numFields() == colNum) + + var idx = 0 + while (reader.nextKeyValue()) { + val row = expectedRows(idx) + val batchRow = reader.getCurrentValue().asInstanceOf[ColumnarBatch.Row].copy() + assert(batchRow === row) + idx += 1 + } + } + } + } + + // Test SparkVectorizedOrcRecordReader: + // SparkVectorizedOrcRecordReader is only used by VectorizedSparkOrcNewRecordReader. + // We test it to see if it correctly constructs Hive's ColumnVector. + + test("Read Orc file with SparkVectorizedOrcRecordReader") { + val colNum = 9 + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + s"$i".getBytes(StandardCharsets.UTF_8)) + } + + withOrcFile(data) { file => + val requiredSchema = new StructType() + .add("_1", StringType) + .add("_2", IntegerType) + .add("_3", LongType) + .add("_4", FloatType) + .add("_5", DoubleType) + .add("_6", ShortType) + .add("_7", ByteType) + .add("_8", BooleanType) + .add("_9", BinaryType) + val reader = getOrcRecordReader(file, requiredSchema) + val hiveBatch = reader.createValue() + assert(hiveBatch.isInstanceOf[VectorizedRowBatch]) + assert(hiveBatch.cols.length == colNum) + + var allRowCount = 0L + while (reader.next(NullWritable.get(), hiveBatch)) { + allRowCount += hiveBatch.count() + } + assert(allRowCount == 256) + } + } + + val notSupportDataTypes = Seq( + ArrayType(IntegerType, true), + MapType(IntegerType, IntegerType, true), + new StructType().add("_1", IntegerType), + TimestampType) + + notSupportDataTypes.map { notSupportDataType => + val seed = System.currentTimeMillis() + val random = new Random(seed) + + test(s"SparkVectorizedOrcRecordReader does not support: $notSupportDataType") { + val requiredSchema = new StructType() + .add("_1", notSupportDataType) + val data = (0 to 255).map { i => + 
RandomDataGenerator.randomRow(random, requiredSchema) + } + withTempPath { file => + spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) + .write.orc(file.getCanonicalPath) + val path = file.getCanonicalPath + val reader = getOrcRecordReader(path, requiredSchema) + val exception = intercept[RuntimeException] { + reader.createValue() + } + assert(exception.getMessage.contains("Vectorization is not supported for datatype")) + } + } + } +}