Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 101 additions & 16 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,50 @@ impl ScalarUDFImpl for SparkHex {
}
}

// NOTE(review): `HEX_CHARS_LOWER` / `HEX_CHARS_UPPER` hold exactly the same
// bytes as the `*_NIBBLES` constants below. They look like the pre-change copy
// of those constants - confirm which pair is meant to remain.
/// Lowercase hex digit alphabet, indexed by nibble value (`0..=15`).
const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
/// Uppercase hex digit alphabet, indexed by nibble value (`0..=15`).
const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";
/// Uppercase hex digit alphabet, indexed by nibble value (`0..=15`).
///
/// Used to seed [`HEX_LOOKUP_UPPER`].
const HEX_CHARS_UPPER_NIBBLES: &[u8; 16] = b"0123456789ABCDEF";
/// Lowercase hex digit alphabet, indexed by nibble value (`0..=15`).
///
/// Used to seed [`HEX_LOOKUP_LOWER`].
const HEX_CHARS_LOWER_NIBBLES: &[u8; 16] = b"0123456789abcdef";

/// Uppercase hex encoding lookup table for fast byte-to-hex conversion.
///
/// Each entry maps a full byte to its two-character hex encoding so the
/// hot loop becomes one load + one two-byte extend per input byte instead
/// of two nibble lookups and two pushes. Built at compile time by
/// `build_hex_lookup`.
const HEX_LOOKUP_UPPER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_UPPER_NIBBLES);
/// Lowercase counterpart of [`HEX_LOOKUP_UPPER`]: entry `b` holds the two
/// lowercase hex digits of byte `b`.
const HEX_LOOKUP_LOWER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_LOWER_NIBBLES);

/// Expands a 16-entry nibble alphabet into a 256-entry byte table.
///
/// Entry `b` of the result is `[nibbles[b >> 4], nibbles[b & 0xF]]`, i.e. the
/// digit for the high nibble followed by the digit for the low nibble.
///
/// Written with plain `while` loops so it can be evaluated in a `const`
/// initializer.
const fn build_hex_lookup(nibbles: &[u8; 16]) -> [[u8; 2]; 256] {
    let mut pairs = [[0u8; 2]; 256];
    let mut high = 0usize;
    while high < 16 {
        let mut low = 0usize;
        while low < 16 {
            // `(high << 4) | low` enumerates every byte value exactly once.
            pairs[(high << 4) | low] = [nibbles[high], nibbles[low]];
            low += 1;
        }
        high += 1;
    }
    pairs
}

#[inline]
fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] {
if num == 0 {
return b"0";
}

// Walk the value two nibbles (one full byte) at a time. The buffer is
// filled from the right so the high-order nibbles end up first; the
// returned slice trims leading zeros automatically.
let mut n = num as u64;
let mut i = 16;
while n != 0 {
while n >= 0x10 {
i -= 2;
let pair = HEX_LOOKUP_UPPER[(n & 0xFF) as usize];
buffer[i] = pair[0];
buffer[i + 1] = pair[1];
n >>= 8;
}
if n > 0 {
// Single remaining high nibble (value 0x1..=0xF).
i -= 1;
buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize];
n >>= 4;
buffer[i] = HEX_CHARS_UPPER_NIBBLES[n as usize];
}
&buffer[i..]
}
Expand All @@ -140,21 +168,21 @@ where
{
let mut builder = StringBuilder::with_capacity(len, len * 64);
let mut buffer = Vec::with_capacity(64);
let hex_chars = if lowercase {
HEX_CHARS_LOWER
let lookup = if lowercase {
&HEX_LOOKUP_LOWER
} else {
HEX_CHARS_UPPER
&HEX_LOOKUP_UPPER
};

for v in iter {
if let Some(b) = v {
buffer.clear();
let bytes = b.as_ref();
buffer.clear();
buffer.reserve(bytes.len() * 2);
for &byte in bytes {
buffer.push(hex_chars[(byte >> 4) as usize]);
buffer.push(hex_chars[(byte & 0x0f) as usize]);
buffer.extend_from_slice(&lookup[byte as usize]);
}
// SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8
// SAFETY: buffer contains only ASCII hex digits, which are valid UTF-8.
unsafe {
builder.append_value(from_utf8_unchecked(&buffer));
}
Expand Down Expand Up @@ -327,7 +355,9 @@ mod test {
use std::str::from_utf8_unchecked;
use std::sync::Arc;

use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray};
use arrow::array::{
BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray,
};
use arrow::{
array::{
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
Expand Down Expand Up @@ -427,19 +457,74 @@ mod test {

/// Checks `hex_int64` at the boundaries that matter for its encoding:
/// zero, single-nibble values, odd and even digit counts around byte
/// boundaries, and the extremes of `i64` (including negative values, which
/// encode as their two's-complement bit pattern).
#[test]
fn test_hex_int64() {
    let test_cases = [
        (0_i64, "0"),
        (1, "1"),
        (15, "F"),
        (16, "10"),
        (255, "FF"),
        (256, "100"),
        (1234, "4D2"),
        (i64::MAX, "7FFFFFFFFFFFFFFF"),
        (i64::MIN, "8000000000000000"),
        (-1, "FFFFFFFFFFFFFFFF"),
    ];

    for (num, expected) in test_cases {
        let mut cache = [0u8; 16];
        let slice = super::hex_int64(num, &mut cache);

        // SAFETY: `hex_int64` only emits ASCII hex digits, which are valid UTF-8.
        let result = unsafe { from_utf8_unchecked(slice) };
        assert_eq!(expected, result, "hex_int64({num}) mismatch");
    }
}

/// Verifies every entry of both precomputed lookup tables against the
/// standard formatter, which serves as an independent hex encoder.
#[test]
fn test_hex_lookup_table_covers_all_bytes() {
    // Walk the two tables in lockstep; the position in the table is the
    // byte value whose encoding the entry must hold.
    let tables = super::HEX_LOOKUP_UPPER
        .iter()
        .zip(super::HEX_LOOKUP_LOWER.iter())
        .enumerate();

    for (value, (upper_pair, lower_pair)) in tables {
        let expected_upper = format!("{value:02X}");
        let expected_lower = format!("{value:02x}");

        assert_eq!(
            expected_upper.as_bytes(),
            upper_pair,
            "upper encoding mismatch for byte 0x{value:02X}"
        );
        assert_eq!(
            expected_lower.as_bytes(),
            lower_pair,
            "lower encoding mismatch for byte 0x{value:02X}"
        );
    }
}

/// Feeds `spark_hex` a single binary row containing every byte value and
/// checks the output against the concatenated two-digit uppercase encoding
/// of each byte, so a regression for any individual byte is caught.
#[test]
fn test_spark_hex_binary_round_trip_all_bytes() {
    let payload: Vec<u8> = (0u8..=255).collect();

    // Independent reference encoding: two uppercase hex digits per byte.
    let expected: String = payload.iter().map(|byte| format!("{byte:02X}")).collect();

    let input = BinaryArray::from(vec![Some(payload.as_slice())]);
    let array = match super::spark_hex(&[ColumnarValue::Array(Arc::new(input))]).unwrap()
    {
        ColumnarValue::Array(array) => array,
        _ => panic!("Expected array"),
    };

    assert_eq!(as_string_array(&array).value(0), expected);
}

#[test]
fn test_spark_hex_int64() {
let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
Expand Down
Loading