diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index a3abe545d529..3e23a059bf3e 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -344,16 +344,43 @@ fn cast_floating_point_to_decimal128( array: &PrimitiveArray, precision: u8, scale: u8, + cast_options: &CastOptions, ) -> Result where ::Native: AsPrimitive, { let mul = 10_f64.powi(scale as i32); - array - .unary::<_, Decimal128Type>(|v| (v.as_() * mul).round() as i128) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) + if cast_options.safe { + let iter = array + .iter() + .map(|v| v.and_then(|v| (mul * v.as_()).round().to_i128())); + let casted_array = + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; + casted_array + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal128Type, _>(|v| { + mul.mul_checked(v.as_()).and_then(|value| { + let mul_v = value.round(); + let integer: i128 = mul_v.to_i128().ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + precision, + scale, + v + )) + })?; + + Ok(integer) + }) + }) + .and_then(|a| a.with_precision_and_scale(precision, scale)) + .map(|a| Arc::new(a) as ArrayRef) + } } fn cast_floating_point_to_decimal256( @@ -588,11 +615,13 @@ pub fn cast_with_options( as_primitive_array::(array), *precision, *scale, + cast_options, ), Float64 => cast_floating_point_to_decimal128( as_primitive_array::(array), *precision, *scale, + cast_options, ), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( @@ -6110,4 +6139,31 @@ mod tests { ); assert!(casted_array.is_err()); } + + #[test] + fn test_cast_floating_point_to_decimal128_overflow() { + let array = Float64Array::from(vec![f64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { safe: true }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { safe: false }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Cast error: Cannot cast to Decimal128(38, 30)"; + assert!( + err.contains(expected_error), + "did not find expected error '{}' in actual error '{}'", + expected_error, + err + ); + } }