diff --git a/core/benches/hash.rs b/core/benches/hash.rs index b878ebea59..d66d589257 100644 --- a/core/benches/hash.rs +++ b/core/benches/hash.rs @@ -19,8 +19,7 @@ mod common; use arrow_array::ArrayRef; -use comet::execution::datafusion::expressions::scalar_funcs::spark_murmur3_hash; -use comet::execution::datafusion::spark_hash::create_xxhash64_hashes; +use comet::execution::datafusion::expressions::scalar_funcs::{spark_murmur3_hash, spark_xxhash64}; use comet::execution::kernels::hash; use common::*; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; @@ -100,12 +99,15 @@ fn criterion_benchmark(c: &mut Criterion) { }, ); group.bench_function(BenchmarkId::new("xxhash64", BATCH_SIZE), |b| { - let input = vec![a3.clone(), a4.clone()]; - let mut dst = vec![0; BATCH_SIZE]; + let inputs = &[ + ColumnarValue::Array(a3.clone()), + ColumnarValue::Array(a4.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(42i64))), + ]; b.iter(|| { for _ in 0..NUM_ITER { - create_xxhash64_hashes(&input, &mut dst).unwrap(); + spark_xxhash64(inputs).unwrap(); } }); }); diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 3c7af86769..e554256423 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -15,14 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::{ - any::Any, - cmp::min, - fmt::{Debug, Write}, - sync::Arc, -}; +use std::{any::Any, cmp::min, fmt::Debug, sync::Arc}; -use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; use arrow::{ array::{ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray, @@ -30,7 +24,7 @@ use arrow::{ }, datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; -use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray}; +use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array}; use arrow_schema::DataType; use datafusion::{ execution::FunctionRegistry, @@ -39,8 +33,8 @@ use datafusion::{ physical_plan::ColumnarValue, }; use datafusion_common::{ - cast::{as_binary_array, as_generic_string_array}, - exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::as_generic_string_array, exec_err, internal_err, DataFusionError, + Result as DataFusionResult, ScalarValue, }; use datafusion_expr::ScalarUDF; use num::{ @@ -58,6 +52,11 @@ use hex::spark_hex; mod chr; use chr::spark_chr; +pub mod hash_expressions; +// exposed for benchmark only +use hash_expressions::wrap_digest_result_as_hex_string; +pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64}; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -635,125 +634,3 @@ fn spark_decimal_div( let result = result.with_data_type(DataType::Decimal128(p3, s3)); Ok(ColumnarValue::Array(Arc::new(result))) } - -pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u32; num_rows]; - hashes.fill(*seed as u32); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_murmur3_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( - hashes[0] as i32, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); - Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", - seed - ) - } - } -} - -fn spark_xxhash64(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u64; num_rows]; - hashes.fill(*seed as u64); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_xxhash64_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( - hashes[0] as i64, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); - Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", - seed - ) - } - } -} - -#[inline] -fn hex_encode>(data: T) -> String { - let mut s = String::with_capacity(data.as_ref().len() * 2); - for b in data.as_ref() { - // Writing to a string never errors, so we can unwrap here. - write!(&mut s, "{b:02x}").unwrap(); - } - s -} - -fn wrap_digest_result_as_hex_string( - args: &[ColumnarValue], - digest: ScalarFunctionImplementation, -) -> Result { - let value = digest(args)?; - match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array - .iter() - .map(|opt| opt.map(hex_encode::<_>)) - .collect(); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( - ScalarValue::Utf8(opt.map(hex_encode::<_>)), - )), - _ => { - exec_err!( - "digest function should return binary value, but got: {:?}", - value.data_type() - ) - } - } -} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs new file mode 100644 index 0000000000..67d7281628 --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::execution::datafusion::expressions::scalar_funcs::hex::hex_strings; +use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion_common::cast::as_binary_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use std::sync::Arc; + +/// Spark compatible murmur3 hash in vectorized execution fashion +pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u32; num_rows]; + hashes.fill(*seed as u32); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_murmur3_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( + hashes[0] as i32, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", + seed + ) + } + } +} + +/// Spark compatible xxhash64 in vectorized execution fashion +pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} + +pub(super) fn wrap_digest_result_as_hex_string( + args: &[ColumnarValue], + digest: ScalarFunctionImplementation, +) -> Result { + let value = digest(args)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index ea572574a1..5191e53fa2 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -34,13 +34,31 @@ fn hex_int64(num: i64) -> String { format!("{:X}", num) } -fn hex_bytes>(bytes: T) -> Result { - let bytes = bytes.as_ref(); - let length = bytes.len(); - let mut hex_string = String::with_capacity(length * 2); - for &byte in bytes { - write!(&mut hex_string, "{:02X}", byte)?; +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02X}").unwrap(); + } } + s +} + +#[inline(always)] +pub(super) fn hex_strings>(data: T) -> String { + hex_encode(data, true) +} + +#[inline(always)] +fn hex_bytes>(bytes: T) -> Result { + let hex_string = hex_encode(bytes, false); Ok(hex_string) } @@ -246,14 +264,14 @@ mod test { fn test_dictionary_hex_binary() { let mut input_builder = BinaryDictionaryBuilder::::new(); input_builder.append_value("1"); - input_builder.append_value("1"); + input_builder.append_value("j"); input_builder.append_null(); input_builder.append_value("3"); let input = input_builder.finish(); let mut expected_builder = StringBuilder::new(); expected_builder.append_value("31"); - expected_builder.append_value("31"); + expected_builder.append_value("6A"); expected_builder.append_null(); expected_builder.append_value("33"); let expected = expected_builder.finish();