Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,65 +118,36 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFus

Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Int64) => {
DataType::Dictionary(_, value_type) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_int64_array(dict.values())?;
let values = hexed_values
let values = match **value_type {
DataType::Int64 => as_int64_array(dict.values())?
.iter()
.map(|v| v.map(hex_int64))
.collect::<Vec<_>>(),
DataType::Utf8 => as_string_array(dict.values())
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?,
DataType::Binary => as_binary_array(dict.values())?
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?,
_ => exec_err!(
"hex got an unexpected argument type: {:?}",
array.data_type()
)?,
};

let new_values: Vec<Option<String>> = dict
.keys()
.iter()
.map(|v| v.map(hex_int64))
.collect::<Vec<_>>();
.map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
.collect();

let keys = dict.keys().clone();
let mut new_keys = Vec::with_capacity(values.len());
let string_array_values = StringArray::from(new_values);

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Utf8) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_string_array(dict.values());
let values: Vec<Option<String>> = hexed_values
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

let keys = dict.keys().clone();

let mut new_keys = Vec::with_capacity(values.len());

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Binary) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_binary_array(dict.values())?;
let values: Vec<Option<String>> = hexed_values
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

let keys = dict.keys().clone();
let mut new_keys = Vec::with_capacity(values.len());

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
_ => exec_err!(
Expand Down
98 changes: 71 additions & 27 deletions native/spark-expr/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use arrow::{
GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
PrimitiveArray,
},
compute::{cast_with_options, unary, CastOptions},
compute::{cast_with_options, take, unary, CastOptions},
datatypes::{
ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type,
TimestampMicrosecondType,
Expand All @@ -40,6 +40,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::DictionaryArray;
use arrow_schema::{DataType, Schema};

use datafusion_common::{
Expand Down Expand Up @@ -98,7 +99,6 @@ macro_rules! cast_utf8_to_int {
result
}};
}

macro_rules! cast_utf8_to_timestamp {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
Expand Down Expand Up @@ -507,19 +507,27 @@ impl Cast {
let to_type = &self.data_type;
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?;
let from_type = array.data_type().clone();

// unpack dictionary string arrays first
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
let array = match &from_type {
DataType::Dictionary(key_type, value_type)
if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)?
let dict_array = array
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.expect("Expected a dictionary array");

let casted_dictionary = DictionaryArray::<Int32Type>::new(
dict_array.keys().clone(),
self.cast_array(dict_array.values().clone())?,
);

let casted_result = match to_type {
DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
_ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
};
return Ok(spark_cast(casted_result, &from_type, to_type));
}
_ => array,
};
Expand Down Expand Up @@ -724,26 +732,31 @@ impl Cast {
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let cast_array: ArrayRef = match to_type {
DataType::Date32 => {
let len = string_array.len();
let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
for i in 0..len {
if !string_array.is_null(i) {
match date_parser(string_array.value(i), eval_mode) {
Ok(Some(cast_value)) => cast_array.append_value(cast_value),
Ok(None) => cast_array.append_null(),
Err(e) => return Err(e),
}
} else {
cast_array.append_null()
}
if to_type != &DataType::Date32 {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a logic change here. Is this just refactoring?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is just a refactoring; there is no logic change here.

unreachable!("Invalid data type {:?} in cast from string", to_type);
}

let len = string_array.len();
let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);

for i in 0..len {
let value = if string_array.is_null(i) {
None
} else {
match date_parser(string_array.value(i), eval_mode) {
Ok(Some(cast_value)) => Some(cast_value),
Ok(None) => None,
Err(e) => return Err(e),
}
Arc::new(cast_array.finish()) as ArrayRef
};

match value {
Some(cast_value) => cast_array.append_value(cast_value),
None => cast_array.append_null(),
}
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
};
Ok(cast_array)
}

Ok(Arc::new(cast_array.finish()) as ArrayRef)
}

fn cast_string_to_timestamp(
Expand Down Expand Up @@ -1796,6 +1809,37 @@ mod tests {
assert_eq!(result.len(), 2);
}

/// Casts an Int32-keyed dictionary of strings to a microsecond timestamp and
/// checks that the result has the requested data type and the same number of
/// rows as the input. The individual cast values are not inspected.
#[test]
fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
    // Target type: microsecond timestamps tagged with the UTC timezone.
    let tz = String::from("UTC");
    let target_type = DataType::Timestamp(TimeUnit::Microsecond, Some(tz.clone().into()));

    // Input: a two-row dictionary whose keys point at two distinct string values.
    // The second value ("T2") is not a full timestamp literal; this test only
    // asserts on the output type and length.
    let dict_values: ArrayRef = Arc::new(StringArray::from(vec![
        Some("2020-01-01T12:34:56.123456"),
        Some("T2"),
    ]));
    let dict_keys = Int32Array::from(vec![0, 1]);
    let input = Arc::new(DictionaryArray::new(dict_keys, dict_values));

    // The child expression is a placeholder; `cast_array` is called directly below.
    let child = Arc::new(Column::new("a", 0));
    let cast = Cast::new(child, target_type.clone(), EvalMode::Legacy, tz);

    // Cast the dictionary array and verify the shape of the result.
    let output = cast.cast_array(input)?;
    assert_eq!(*output.data_type(), target_type);
    assert_eq!(output.len(), 2);

    Ok(())
}

#[test]
fn date_parser_test() {
for date in &[
Expand Down