Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use crate::function::error_utils::{
invalid_arg_count_exec_err, unsupported_data_type_exec_err,
};
use arrow::array::{Array, StringArray};
use arrow::array::{Array, GenericStringBuilder, StringArray};
use arrow::datatypes::DataType;
use arrow::{
array::{as_dictionary_array, as_largestring_array, as_string_array},
Expand All @@ -35,6 +35,10 @@ use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
use std::fmt::Write;

/// Max length of a hex-encoded uint64 value is 16 characters.
/// E.g., `i64::MAX` is `7FFFFFFFFFFFFFFF` which is 16 characters long.
const MAX_HEX_INT64_LEN: usize = 16;

/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
#[derive(Debug)]
pub struct SparkHex {
Expand Down Expand Up @@ -196,11 +200,23 @@ pub fn compute_hex(
ColumnarValue::Array(array) => match array.data_type() {
DataType::Int64 => {
let array = as_int64_array(array)?;
let arr_len = array.len();
let arr_content_len = arr_len * MAX_HEX_INT64_LEN;

// Optimized using a builder to avoid intermediate allocations
// during iter().map(...).collect()
let mut builder =
GenericStringBuilder::<i32>::with_capacity(arr_len, arr_content_len);
for value in array {
match value {
None => builder.append_null(),
Some(v) => builder.append_value(hex_int64(v)),
};
}

let hexed_array: StringArray =
array.iter().map(|v| v.map(hex_int64)).collect();
let hexed = builder.finish();

Ok(ColumnarValue::Array(Arc::new(hexed_array)))
Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::Utf8 => {
let array = as_string_array(array);
Expand Down Expand Up @@ -333,13 +349,17 @@ mod test {
input_builder.append_value(2);
input_builder.append_null();
input_builder.append_value(3);
input_builder.append_value(i64::MAX);
input_builder.append_value(i64::MIN);
let input = input_builder.finish();

let mut string_builder = StringBuilder::new();
string_builder.append_value("1");
string_builder.append_value("2");
string_builder.append_null();
string_builder.append_value("3");
string_builder.append_value("7FFFFFFFFFFFFFFF");
string_builder.append_value("8000000000000000");
let expected = string_builder.finish();

let columnar_value = ColumnarValue::Array(Arc::new(input));
Expand Down Expand Up @@ -397,7 +417,14 @@ mod test {

#[test]
fn test_spark_hex_int64() {
let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
let int_array = Int64Array::from(vec![
Some(1),
Some(2),
None,
Some(3),
Some(i64::MAX),
Some(i64::MIN),
]);
let columnar_value = ColumnarValue::Array(Arc::new(int_array));

let result = super::spark_hex(&[columnar_value]).unwrap();
Expand All @@ -412,6 +439,8 @@ mod test {
Some("2".to_string()),
None,
Some("3".to_string()),
Some("7FFFFFFFFFFFFFFF".to_string()),
Some("8000000000000000".to_string()),
]);

assert_eq!(string_array, &expected_array);
Expand Down
Loading