From 71841a392218218ea972250707e1ea45a24c2642 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Wed, 2 Nov 2022 17:52:24 +0800 Subject: [PATCH 1/5] add .round() before casting to integer --- arrow/src/compute/kernels/cast.rs | 70 ++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 4c724b6401b9..3081a9e86150 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -353,7 +353,7 @@ where { let mul = 10_f64.powi(scale as i32); - unary::(array, |v| (v.as_() * mul) as i128) + unary::(array, |v| (v.as_() * mul).round() as i128) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } @@ -368,9 +368,11 @@ where { let mul = 10_f64.powi(scale as i32); - unary::(array, |v| i256::from_i128((v.as_() * mul) as i128)) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) + unary::(array, |v| { + i256::from_i128((v.as_() * mul).round() as i128) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) } /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] @@ -3192,8 +3194,8 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_7), - Some(1.123_456_7), + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3205,8 +3207,8 @@ mod tests { Some(2200000_i128), Some(4400000_i128), None, - Some(1123456_i128), - Some(1123456_i128), + Some(1123456_i128), // round down + Some(1123457_i128), // round up ] ); @@ -3216,9 +3218,10 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_789_123_4), - Some(1.123_456_789_012_345_6), - Some(1.123_456_789_012_345_6), + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up ]); let array = 
Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3230,9 +3233,10 @@ mod tests { Some(2200000_i128), Some(4400000_i128), None, - Some(1123456_i128), - Some(1123456_i128), - Some(1123456_i128), + Some(1123456_i128), // round down + Some(1123457_i128), // round up + Some(1123456_i128), // round down + Some(1123457_i128), // round up ] ); } @@ -3307,8 +3311,8 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_7), - Some(1.123_456_7), + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3320,8 +3324,8 @@ mod tests { Some(i256::from_i128(2200000_i128)), Some(i256::from_i128(4400000_i128)), None, - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up ] ); @@ -3331,9 +3335,10 @@ mod tests { Some(2.2), Some(4.4), None, - Some(1.123_456_789_123_4), - Some(1.123_456_789_012_345_6), - Some(1.123_456_789_012_345_6), + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( @@ -3345,9 +3350,10 @@ mod tests { Some(i256::from_i128(2200000_i128)), Some(i256::from_i128(4400000_i128)), None, - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), - Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up ] ); } @@ -5994,4 +6000,20 @@ mod tests { .collect::>(); assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_f64_to_decimal128() { + // to reproduce 
https://github.com/apache/arrow-rs/issues/2997 + + let decimal_type = DataType::Decimal128(18, 2); + let array = Float64Array::from(vec![Some(0.0699999999)]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![Some(7_i128),] + ); + } } From f230cf7a6b6ddb1124d5443d975662253ca99c04 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Wed, 2 Nov 2022 19:06:51 +0800 Subject: [PATCH 2/5] add more test cases --- arrow/src/compute/kernels/cast.rs | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 3081a9e86150..b0cb0c641969 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -6007,13 +6007,39 @@ mod tests { // to reproduce https://github.com/apache/arrow-rs/issues/2997 let decimal_type = DataType::Decimal128(18, 2); - let array = Float64Array::from(vec![Some(0.0699999999)]); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0649999999), + ]); let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( &array, Decimal128Array, &decimal_type, - vec![Some(7_i128),] + vec![ + Some(7_i128), // round up + Some(7_i128), // round up + Some(6_i128), // round down + ] + ); + + let decimal_type = DataType::Decimal128(18, 3); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(70_i128), // round up + Some(66_i128), // round up + Some(65_i128), // round up + ] ); } } From 5d61321cf7ee286dea61449c17800e458e0f36a4 Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Wed, 2 Nov 2022 23:22:36 +0800 Subject: [PATCH 3/5] update test cases --- arrow/src/compute/kernels/cast.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index b0cb0c641969..fbabc1ab8acd 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -6010,6 +6010,7 @@ mod tests { let array = Float64Array::from(vec![ Some(0.0699999999), Some(0.0659999999), + Some(0.0650000000), Some(0.0649999999), ]); let array = Arc::new(array) as ArrayRef; @@ -6018,6 +6019,7 @@ mod tests { Decimal128Array, &decimal_type, vec![ + Some(7_i128), // round up Some(7_i128), // round up Some(7_i128), // round up Some(6_i128), // round down @@ -6028,6 +6030,7 @@ mod tests { let array = Float64Array::from(vec![ Some(0.0699999999), Some(0.0659999999), + Some(0.0650000000), Some(0.0649999999), ]); let array = Arc::new(array) as ArrayRef; @@ -6038,6 +6041,7 @@ mod tests { vec![ Some(70_i128), // round up Some(66_i128), // round up + Some(65_i128), // round down Some(65_i128), // round up ] ); From b10e35d440be3c290ee2e10502b2fc23f033c56c Mon Sep 17 00:00:00 2001 From: Wei-Ting Kuo Date: Wed, 2 Nov 2022 23:31:35 +0800 Subject: [PATCH 4/5] add doc --- arrow/src/compute/kernels/cast.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index fbabc1ab8acd..fac5a79734bc 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// * Time32 and Time64: precision lost when going to higher interval /// * Timestamp and Date{32|64}: precision lost when going to higher interval /// * Temporal to/from backing primitive: zero-copy with data type change +/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places +/// (e.g. casting 6.4999 to Decimal(10, 1) yields 6.5). This is a breaking change from `26.0.0`. +/// It used to truncate instead of rounding (e.g. 
the cast above produced 6.4) /// /// Unsupported Casts /// * To or from `StructArray` From 8d249fd66066d056cb64a78e807446dc2878d48d Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 3 Nov 2022 17:31:45 +1300 Subject: [PATCH 5/5] Format --- arrow/src/compute/kernels/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index fac5a79734bc..4ad8dd99e73e 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -297,8 +297,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// * Time32 and Time64: precision lost when going to higher interval /// * Timestamp and Date{32|64}: precision lost when going to higher interval /// * Temporal to/from backing primitive: zero-copy with data type change -/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places -/// (e.g. casting 6.4999 to Decimal(10, 1) yields 6.5). This is a breaking change from `26.0.0`. +/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places +/// (e.g. casting 6.4999 to Decimal(10, 1) yields 6.5). This is a breaking change from `26.0.0`. /// It used to truncate instead of rounding (e.g. the cast above produced 6.4) /// /// Unsupported Casts