diff --git a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 5191e53fa2..e6059818b9 100644 --- a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -118,65 +118,36 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + DataType::Dictionary(_, value_type) => { let dict = as_dictionary_array::(&array); - let hexed_values = as_int64_array(dict.values())?; - let values = hexed_values + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() .iter() - .map(|v| v.map(hex_int64)) - .collect::>(); + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); - let keys = dict.keys().clone(); - let mut new_keys = Vec::with_capacity(values.len()); + let string_array_values = StringArray::from(new_values); - for key in keys.iter() { - let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); - new_keys.push(key); - } - - let string_array_values = StringArray::from(new_keys); - Ok(ColumnarValue::Array(Arc::new(string_array_values))) - } - DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Utf8) => { - let dict = as_dictionary_array::(&array); - - let hexed_values = as_string_array(dict.values()); - let values: Vec> = hexed_values - .iter() - .map(|v| v.map(hex_bytes).transpose()) - .collect::>()?; - - let keys = dict.keys().clone(); - - let mut new_keys = Vec::with_capacity(values.len()); - - for key in keys.iter() { - let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); - new_keys.push(key); - } - - let string_array_values = StringArray::from(new_keys); - Ok(ColumnarValue::Array(Arc::new(string_array_values))) - } - DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Binary) => { - let dict = as_dictionary_array::(&array); - - let hexed_values = as_binary_array(dict.values())?; - let values: Vec> = hexed_values - .iter() - .map(|v| v.map(hex_bytes).transpose()) - .collect::>()?; - - let keys = dict.keys().clone(); - let mut new_keys = Vec::with_capacity(values.len()); - - for key in keys.iter() { - let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); - new_keys.push(key); - } - - let string_array_values = StringArray::from(new_keys); Ok(ColumnarValue::Array(Arc::new(string_array_values))) } _ => exec_err!( diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index 7f53583e8d..8702ce7070 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -31,7 +31,7 @@ use arrow::{ GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }, - compute::{cast_with_options, unary, CastOptions}, + compute::{cast_with_options, take, unary, CastOptions}, datatypes::{ ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type, TimestampMicrosecondType, @@ -40,6 +40,7 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; +use arrow_array::DictionaryArray; use arrow_schema::{DataType, Schema}; use datafusion_common::{ @@ -98,7 +99,6 @@ macro_rules! cast_utf8_to_int { result }}; } - macro_rules! cast_utf8_to_timestamp { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); @@ -507,19 +507,27 @@ impl Cast { let to_type = &self.data_type; let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?; let from_type = array.data_type().clone(); - - // unpack dictionary string arrays first - // TODO: we are unpacking a dictionary-encoded array and then performing - // the cast. We could potentially improve performance here by casting the - // dictionary values directly without unpacking the array first, although this - // would add more complexity to the code let array = match &from_type { DataType::Dictionary(key_type, value_type) if key_type.as_ref() == &DataType::Int32 && (value_type.as_ref() == &DataType::Utf8 || value_type.as_ref() == &DataType::LargeUtf8) => { - cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a dictionary array"); + + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + self.cast_array(dict_array.values().clone())?, + ); + + let casted_result = match to_type { + DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()), + _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, + }; + return Ok(spark_cast(casted_result, &from_type, to_type)); } _ => array, }; @@ -724,26 +732,31 @@ impl Cast { .downcast_ref::>() .expect("Expected a string array"); - let cast_array: ArrayRef = match to_type { - DataType::Date32 => { - let len = string_array.len(); - let mut cast_array = PrimitiveArray::::builder(len); - for i in 0..len { - if !string_array.is_null(i) { - match date_parser(string_array.value(i), eval_mode) { - Ok(Some(cast_value)) => cast_array.append_value(cast_value), - Ok(None) => cast_array.append_null(), - Err(e) => return Err(e), - } - } else { - cast_array.append_null() - } + if to_type != &DataType::Date32 { + unreachable!("Invalid data type {:?} in cast from string", to_type); + } + + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + + for i in 0..len { + let value = if string_array.is_null(i) { + None + } else { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => Some(cast_value), + Ok(None) => None, + Err(e) => return Err(e), } - Arc::new(cast_array.finish()) as ArrayRef + }; + + match value { + Some(cast_value) => cast_array.append_value(cast_value), + None => cast_array.append_null(), } - _ => unreachable!("Invalid data type {:?} in cast from string", to_type), - }; - Ok(cast_array) + } + + Ok(Arc::new(cast_array.finish()) as ArrayRef) } fn cast_string_to_timestamp( @@ -1796,6 +1809,37 @@ mod tests { assert_eq!(result.len(), 2); } + #[test] + fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> { + // prepare input data + let keys = Int32Array::from(vec![0, 1]); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)); + + // prepare cast expression + let timezone = "UTC".to_string(); + let expr = Arc::new(Column::new("a", 0)); // this is not used by the test + let cast = Cast::new( + expr, + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), + EvalMode::Legacy, + timezone.clone(), + ); + + // test casting string dictionary array to timestamp array + let result = cast.cast_array(dict_array)?; + assert_eq!( + *result.data_type(), + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into())) + ); + assert_eq!(result.len(), 2); + + Ok(()) + } + #[test] fn date_parser_test() { for date in &[