diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 4c724b6401b9..4ad8dd99e73e 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// * Time32 and Time64: precision lost when going to higher interval /// * Timestamp and Date{32|64}: precision lost when going to higher interval /// * Temporal to/from backing primitive: zero-copy with data type change +/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places +/// (e.g. casting 6.4999 to Decimal(10, 1) yields 6.5). This is a breaking change from `26.0.0`, +/// which truncated instead of rounding (e.g. it produced 6.4 for the same input). /// /// Unsupported Casts /// * To or from `StructArray` @@ -353,7 +356,7 @@ where { let mul = 10_f64.powi(scale as i32); - unary::<_, _, Decimal128Type>(array, |v| (v.as_() * mul) as i128) + unary::<_, _, Decimal128Type>(array, |v| (v.as_() * mul).round() as i128) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } @@ -368,9 +371,11 @@ where { let mul = 10_f64.powi(scale as i32); - unary::<_, _, Decimal256Type>(array, |v| i256::from_i128((v.as_() * mul) as i128)) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) + unary::<_, _, Decimal256Type>(array, |v| { + i256::from_i128((v.as_() * mul).round() as i128) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) } /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] @@ -3192,8 +3197,8 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_7), - Some(1.123_456_7), + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3205,8 +3210,8 @@ mod tests { Some(2200000_i128), Some(4400000_i128), None, - Some(1123456_i128), - Some(1123456_i128), + Some(1123456_i128), // round down + Some(1123457_i128), // round up ] ); @@ -3216,9 
+3221,10 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_789_123_4), - Some(1.123_456_789_012_345_6), - Some(1.123_456_789_012_345_6), + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3230,9 +3236,10 @@ mod tests { Some(2200000_i128), Some(4400000_i128), None, - Some(1123456_i128), - Some(1123456_i128), - Some(1123456_i128), + Some(1123456_i128), // round down + Some(1123457_i128), // round up + Some(1123456_i128), // round down + Some(1123457_i128), // round up ] ); } @@ -3307,8 +3314,8 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_7), - Some(1.123_456_7), + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3320,8 +3327,8 @@ mod tests { Some(i256::from_i128(2200000_i128)), Some(i256::from_i128(4400000_i128)), None, - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up ] ); @@ -3331,9 +3338,10 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_789_123_4), - Some(1.123_456_789_012_345_6), - Some(1.123_456_789_012_345_6), + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3345,9 +3353,10 @@ mod tests { Some(i256::from_i128(2200000_i128)), Some(i256::from_i128(4400000_i128)), None, - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + 
Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up ] ); } @@ -5994,4 +6003,50 @@ mod tests { .collect::<Vec<_>>(); assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_f64_to_decimal128() { + // to reproduce https://github.com/apache/arrow-rs/issues/2997 + + let decimal_type = DataType::Decimal128(18, 2); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(7_i128), // round up + Some(7_i128), // round up + Some(7_i128), // round up + Some(6_i128), // round down + ] + ); + + let decimal_type = DataType::Decimal128(18, 3); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(70_i128), // round up + Some(66_i128), // round up + Some(65_i128), // round down + Some(65_i128), // round up + ] + ); + } }