Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 79 additions & 24 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
/// * Time32 and Time64: precision lost when going to higher interval
/// * Timestamp and Date{32|64}: precision lost when going to higher interval
/// * Temporal to/from backing primitive: zero-copy with data type change
/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to the `scale` decimals
/// (i.e. casting 6.4999 to Decimal(10, 1) becomes 6.5). This is the breaking change from `26.0.0`.
/// It used to truncate it instead of round (i.e. outputs 6.4 instead)
///
/// Unsupported Casts
/// * To or from `StructArray`
Expand Down Expand Up @@ -353,7 +356,7 @@ where
{
let mul = 10_f64.powi(scale as i32);

unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul) as i128)
unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul).round() as i128)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}
Expand All @@ -368,9 +371,11 @@ where
{
let mul = 10_f64.powi(scale as i32);

unary::<T, _, Decimal256Type>(array, |v| i256::from_i128((v.as_() * mul) as i128))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
unary::<T, _, Decimal256Type>(array, |v| {
i256::from_i128((v.as_() * mul).round() as i128)
})
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
Expand Down Expand Up @@ -3192,8 +3197,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_7),
Some(1.123_456_7),
Some(1.123_456_4), // round down
Some(1.123_456_7), // round up
Comment on lines +3200 to +3201

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i change the test cases to

  1. round down
  2. round up

]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3205,8 +3210,8 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128), // round down
Some(1123457_i128), // round up
Comment on lines +3213 to +3214

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here's the result

]
);

Expand All @@ -3216,9 +3221,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_789_123_4),
Some(1.123_456_789_012_345_6),
Some(1.123_456_789_012_345_6),
Some(1.123_456_489_123_4), // round up
Some(1.123_456_789_123_4), // round up
Some(1.123_456_489_012_345_6), // round down
Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3230,9 +3236,10 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128),
Some(1123456_i128), // round down
Some(1123457_i128), // round up
Some(1123456_i128), // round down
Some(1123457_i128), // round up
]
);
}
Expand Down Expand Up @@ -3307,8 +3314,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_7),
Some(1.123_456_7),
Some(1.123_456_4), // round down
Some(1.123_456_7), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3320,8 +3327,8 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
]
);

Expand All @@ -3331,9 +3338,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
Some(1.123_456_789_123_4),
Some(1.123_456_789_012_345_6),
Some(1.123_456_789_012_345_6),
Some(1.123_456_489_123_4), // round down
Some(1.123_456_789_123_4), // round up
Some(1.123_456_489_012_345_6), // round down
Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -3345,9 +3353,10 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)),
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
Some(i256::from_i128(1123456_i128)), // round down
Some(i256::from_i128(1123457_i128)), // round up
]
);
}
Expand Down Expand Up @@ -5994,4 +6003,50 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]);
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_f64_to_decimal128() {
// to reproduce https://github.com/apache/arrow-rs/issues/2997

let decimal_type = DataType::Decimal128(18, 2);
let array = Float64Array::from(vec![
Some(0.0699999999),
Some(0.0659999999),
Some(0.0650000000),
Some(0.0649999999),
Comment on lines +6012 to +6017

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jimexist added the 0.065 case which will be rounded up to 0.07

]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
&array,
Decimal128Array,
&decimal_type,
vec![
Some(7_i128), // round up
Some(7_i128), // round up
Some(7_i128), // round up
Some(6_i128), // round down
]
);

let decimal_type = DataType::Decimal128(18, 3);
let array = Float64Array::from(vec![
Some(0.0699999999),
Some(0.0659999999),
Some(0.0650000000),
Some(0.0649999999),
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
&array,
Decimal128Array,
&decimal_type,
vec![
Some(70_i128), // round up
Some(66_i128), // round up
Some(65_i128), // round down
Some(65_i128), // round up
]
);
}
}