From ceb75937acf9d74faf3c00e8c259516ee42ce477 Mon Sep 17 00:00:00 2001 From: theirix Date: Sat, 30 May 2026 18:13:10 +0100 Subject: [PATCH 01/11] Support decimals for gcd and lcm --- datafusion/functions/src/math/common.rs | 320 +++++++++++++++ datafusion/functions/src/math/gcd.rs | 431 +++++++++++++++++--- datafusion/functions/src/math/lcm.rs | 299 ++++++++++++-- datafusion/functions/src/math/mod.rs | 1 + datafusion/functions/src/utils.rs | 90 +++- datafusion/sqllogictest/test_files/math.slt | 46 +++ 6 files changed, 1077 insertions(+), 110 deletions(-) create mode 100644 datafusion/functions/src/math/common.rs diff --git a/datafusion/functions/src/math/common.rs b/datafusion/functions/src/math/common.rs new file mode 100644 index 0000000000000..9bb6f6fe1e35c --- /dev/null +++ b/datafusion/functions/src/math/common.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrowNativeTypeOp; +use arrow::error::ArrowError; +use num_traits::{CheckedMul, CheckedNeg, Signed}; +use std::fmt::Display; +use std::mem::swap; +use std::ops::RemAssign; + +/// A gcd helper to compute GCD using Euclidean GCD algorithm +/// on non-negative numbers (scalars and decimals) +fn gcd_helper(a: T, b: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + CheckedNeg, +{ + debug_assert!(a >= T::ZERO); + debug_assert!(b >= T::ZERO); + let (mut a, mut b) = if a > b { (a, b) } else { (b, a) }; + + while b != T::ZERO { + swap(&mut a, &mut b); + b %= a; + } + + Ok(a) +} + +/// Computes gcd of two unsigned integers using Binary GCD algorithm +/// Faster, works with integers only +pub(crate) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { + if a == 0 { + return b; + } + if b == 0 { + return a; + } + + let shift = (a | b).trailing_zeros(); + a >>= a.trailing_zeros(); + loop { + b >>= b.trailing_zeros(); + if a > b { + swap(&mut a, &mut b); + } + b -= a; + if b == 0 { + return a << shift; + } + } +} + +/// Computes gcd of two signed numbers (integers or decimals), +/// checking for output integer overflow +pub(crate) fn gcd_signed(x: T, y: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg, +{ + // Make absolute values, keeping type + let a = if x.is_positive() { + x + } else { + x.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + let b = if y.is_positive() { + y + } else { + y.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + // Call with signed numbers + gcd_helper(a, b) +} + +/// Computes gcd of two signed integers +pub(crate) fn gcd_signed_int(x: i64, y: i64) -> Result { + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + + // Call with unsigned numbers + let r = unsigned_gcd(a, b); + // gcd(i64::MIN, i64::MIN) = u64::MIN.unsigned_abs() cannot fit into i64 + r.try_into().map_err(|_| { + ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) + }) +} + +/// Computes lcm of two signed numbers (integers or decimals) +pub(crate) fn lcm_signed(x: T, y: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg + CheckedMul + Display, +{ + if x == T::ZERO || y == T::ZERO { + return Ok(T::ZERO); + } + + // Make absolute values, keeping type + let a = if x.is_positive() { + x + } else { + x.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + let b = if y.is_positive() { + y + } else { + y.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + // Call with signed numbers + let gcd = gcd_helper(a, b)?; + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd).checked_mul(&b).ok_or_else(|| { + ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})")) + }) +} + +/// Computes lcm of two signed integers, +/// checking for output integer overflow +pub(crate) fn lcm_signed_int(x: i64, y: i64) -> Result { + if x == 0 || y == 0 { + return Ok(0); + } + + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + + let gcd = gcd_helper::(a, b)?; + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd) + .checked_mul(b) + .and_then(|v| i64::try_from(v).ok()) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})")) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::i256; + + const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ + // Basic cases + (48, 18, 6), + (54, 24, 6), + (100, 50, 50), + (17, 19, 1), + (21, 14, 7), + // Edge cases with 0 + (0, 0, 0), + (0, 5, 5), + (10, 0, 10), + // Same numbers + (7, 7, 7), + (100, 100, 100), + // One is 1 + (1, 1, 1), + (1, 100, 1), + (999, 1, 1), + // Large numbers + (1000000, 500000, 500000), + (123456, 789012, 12), + (999999, 111111, 111111), + // Powers of 2 + (64, 128, 64), + (1024, 2048, 1024), + ]; + + const LCM_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ + // Basic cases + (48, 18, 144), + (54, 24, 216), + (100, 50, 100), + (17, 19, 323), + (21, 14, 42), + // Edge cases with 0 + (0, 0, 0), + (0, 5, 0), + (10, 0, 0), + // Same numbers + (7, 7, 7), + (100, 100, 100), + // One is 1 + (1, 1, 1), + (1, 100, 100), + (999, 1, 999), + // Large numbers + (1_000_000, 500_000, 1_000_000), + (123_456, 789_012, 8_117_355_456), + (999_999, 111_111, 999_999), + // Powers of 2 + (64, 128, 128), + (1024, 2048, 2048), + ]; + + #[test] + fn test_gcd_i64() { + let test_cases: Vec<(i64, i64, i64)> = [ + GCD_COMMON_TEST_CASES.into(), + vec![ + // Max value cases + (1, i64::MAX, 1), + (i64::MAX, 1, 1), + (i64::MAX, i64::MAX, i64::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual_euclidean = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual_euclidean, expected, + "gcd_signed({a}, {b}) expected {expected}, actual {actual_euclidean}" + ); + let actual_binary: i64 = + unsigned_gcd(a.try_into().unwrap(), b.try_into().unwrap()) + .try_into() + .expect("overflow"); + assert_eq!( + actual_binary, expected, + "unsigned_gcd({a}, {b}) expected {expected}, actual {actual_binary}" + ); + } + } + + #[test] + fn test_gcd_decimal() { + let test_cases: Vec<(i256, i256, i256)> = [ + GCD_COMMON_TEST_CASES + .iter() + .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) + .collect(), + vec![ + (i256::from(1), i256::MAX, i256::from(1)), + (i256::MAX, i256::from(1), i256::from(1)), + (i256::MAX, i256::MAX, i256::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" + ); + } + } + + #[test] + fn test_lcm_i64() { + let test_cases: Vec<(i64, i64, i64)> = [ + LCM_COMMON_TEST_CASES.into(), + vec![ + // Negative inputs - LCM is always non-negative + (-6, 4, 12), + (-4, -6, 12), + // Max value cases + (1, i64::MAX, i64::MAX), + (i64::MAX, 1, i64::MAX), + (i64::MAX, i64::MAX, i64::MAX), + ], + ] + .concat(); + + for (a, b, expected) in test_cases { + let actual = lcm_signed_int(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "lcm_signed_int({a}, {b}) expected {expected}, actual {actual}" + ); + } + } + + #[test] + fn test_lcm_decimal() { + let test_cases: Vec<(i256, i256, i256)> = [ + LCM_COMMON_TEST_CASES + .iter() + .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) + .collect(), + vec![ + // Negative inputs - LCM is always non-negative + (i256::from(-6_i64), i256::from(4_i64), i256::from(12_i64)), + (i256::from(-4_i64), i256::from(-6_i64), i256::from(12_i64)), + // Max value cases + (i256::from(1_i64), i256::MAX, i256::MAX), + (i256::MAX, i256::from(1_i64), i256::MAX), + (i256::MAX, i256::MAX, i256::MAX), + ], + ] + .concat(); + + for (a, b, expected) in test_cases { + let actual = lcm_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "lcm_signed({a}, {b}) expected {expected}, actual {actual}" + ); + } + } +} diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 8b92c454d9b4c..b9490dd0dc845 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -17,16 +17,20 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::try_binary; -use arrow::datatypes::{DataType, Int64Type}; -use arrow::error::ArrowError; -use std::mem::swap; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Int64Type, +}; use std::sync::Arc; +use crate::math::common::{gcd_signed, gcd_signed_int, unsigned_gcd}; +use crate::utils::calculate_binary_decimal_math_cast; +use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::type_coercion::binary::decimal_coercion; use datafusion_macros::user_doc; #[user_doc( @@ -58,11 +62,7 @@ impl Default for GcdFunc { impl GcdFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - vec![DataType::Int64], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -76,37 +76,123 @@ impl ScalarUDFImpl for GcdFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int64) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg1, arg2] = take_function_args(self.name(), arg_types)?; + + let coerced_type = match (arg1, arg2) { + (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { + decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { + exec_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }) + } + (lhs, rhs) => { + exec_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + } + }?; + Ok(vec![coerced_type.clone(), coerced_type]) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let number_rows = args.number_rows; let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { internal_datafusion_err!("Expected 2 arguments for function gcd") })?; - match args { - [ColumnarValue::Array(a), ColumnarValue::Array(b)] => { - compute_gcd_for_arrays(&a, &b) + if args[0].data_type() == DataType::Int64 { + // Optimized path for both integers + match args { + [ColumnarValue::Array(a), ColumnarValue::Array(b)] => { + compute_gcd_for_arrays(&a, &b) + } + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => match (a, b) { + (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + Some(gcd_signed_int(a, b)?), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), + }, + [ + ColumnarValue::Array(a), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => compute_gcd_with_scalar(&a, b), + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Array(b), + ] => compute_gcd_with_scalar(&b, a), + _ => exec_err!("Unsupported argument types for function gcd"), } - [ - ColumnarValue::Scalar(ScalarValue::Int64(a)), - ColumnarValue::Scalar(ScalarValue::Int64(b)), - ] => match (a, b) { - (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - Some(compute_gcd(a, b)?), - ))), - _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), - }, - [ - ColumnarValue::Array(a), - ColumnarValue::Scalar(ScalarValue::Int64(b)), - ] => compute_gcd_with_scalar(&a, b), - [ - ColumnarValue::Scalar(ScalarValue::Int64(a)), - ColumnarValue::Array(b), - ] => compute_gcd_with_scalar(&b, a), - _ => exec_err!("Unsupported argument types for function gcd"), + } else { + // Decimal path: convert left to array and use generic helper + let left = args[0].to_array(number_rows)?; + let right = &args[1]; + + let arr: ArrayRef = match (left.data_type(), right.data_type()) { + ( + lhs @ DataType::Decimal32(precision, scale), + rhs @ DataType::Decimal32(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal32Type, + Decimal32Type, + Decimal32Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal64(precision, scale), + rhs @ DataType::Decimal64(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal64Type, + Decimal64Type, + Decimal64Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal128(precision, scale), + rhs @ DataType::Decimal128(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal256(precision, scale), + rhs @ DataType::Decimal256(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal256Type, + Decimal256Type, + Decimal256Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + (lhs, rhs) => { + exec_err!( + "Unsupported data types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }?, + }; + Ok(ColumnarValue::Array(arr)) } } @@ -118,7 +204,7 @@ impl ScalarUDFImpl for GcdFunc { fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result { let a = a.as_primitive::(); let b = b.as_primitive::(); - try_binary(a, b, compute_gcd) + try_binary(a, b, gcd_signed_int) .map(|arr: PrimitiveArray| { ColumnarValue::Array(Arc::new(arr) as ArrayRef) }) @@ -141,44 +227,271 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result { let result: PrimitiveArray = - prim.try_unary(|val| compute_gcd(val, scalar_value))?; + prim.try_unary(|val| gcd_signed_int(val, scalar_value))?; Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), } } -/// Computes gcd of two unsigned integers using Binary GCD algorithm. -pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { - if a == 0 { - return b; +#[cfg(test)] +mod tests { + use super::*; + use crate::math::common::gcd_signed; + use arrow::array::{Array, Decimal128Array, Int64Array}; + use arrow::datatypes::{DECIMAL128_MAX_PRECISION, Field}; + use arrow_buffer::i256; + use datafusion_common::ScalarValue; + use datafusion_common::cast::{as_decimal128_array, as_int64_array}; + use datafusion_common::config::ConfigOptions; + use std::sync::Arc; + + #[test] + fn test_i64_array() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("b", DataType::Int64, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![ + 0, 2, 0, 2, 15, 20, + ]))), + ColumnarValue::Array(Arc::new(Int64Array::from(vec![ + 0, 0, 2, 3, 10, 1000, + ]))), + ], + arg_fields, + number_rows: 6, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = GcdFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function"); + + match result { + ColumnarValue::Array(arr) => { + let values = + as_int64_array(&arr).expect("failed to convert result to an array"); + assert_eq!(values.len(), 6); + assert_eq!(values.value(0), 0); + assert_eq!(values.value(1), 2); + assert_eq!(values.value(2), 2); + assert_eq!(values.value(3), 1); + assert_eq!(values.value(4), 5); + assert_eq!(values.value(5), 20); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } - if b == 0 { - return a; + + #[test] + fn test_decimal_scalar() { + let arg_fields = vec![ + Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(2)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(3)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ], + arg_fields, + number_rows: 1, + return_field: Field::new( + "f", + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = GcdFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let ints = as_decimal128_array(&arr) + .expect("failed to convert result to an array"); + + assert_eq!(ints.len(), 1); + assert_eq!(ints.value(0), i128::from(1)); + // Signature stays the same as input + assert_eq!( + *arr.data_type(), + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } - let shift = (a | b).trailing_zeros(); - a >>= a.trailing_zeros(); - loop { - b >>= b.trailing_zeros(); - if a > b { - swap(&mut a, &mut b); + #[test] + fn test_decimal_array_scalar() { + let arg_fields = vec![ + Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![2, 15]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0) + .unwrap(), + )), + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(3)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ], + arg_fields, + number_rows: 2, + return_field: Field::new( + "f", + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = GcdFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let ints = as_decimal128_array(&arr) + .expect("failed to convert result to an array"); + + assert_eq!(ints.len(), 2); + assert_eq!(ints.value(0), i128::from(1)); + assert_eq!(ints.value(1), i128::from(3)); + // Signature stays the same as input + assert_eq!( + *arr.data_type(), + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } } - b -= a; - if b == 0 { - return a << shift; + } + + #[test] + fn test_coercion() { + let mut coerced = GcdFunc::new() + .coerce_types(&[DataType::Int64, DataType::Int32]) + .expect("coercion should succeed"); + assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]); + + coerced = GcdFunc::new() + .coerce_types(&[DataType::Decimal128(10, 2), DataType::Int32]) + .expect("coercion should succeed"); + + assert_eq!( + coerced, + vec![DataType::Decimal128(12, 2), DataType::Decimal128(12, 2)] + ); + + coerced = GcdFunc::new() + .coerce_types(&[DataType::Decimal128(10, 2), DataType::Null]) + .expect("coercion should succeed"); + + assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]); + } + + const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ + // Basic cases + (48, 18, 6), + (54, 24, 6), + (100, 50, 50), + (17, 19, 1), + (21, 14, 7), + // Edge cases with 0 + (0, 0, 0), + (0, 5, 5), + (10, 0, 10), + // Same numbers + (7, 7, 7), + (100, 100, 100), + // One is 1 + (1, 1, 1), + (1, 100, 1), + (999, 1, 1), + // Large numbers + (1000000, 500000, 500000), + (123456, 789012, 12), + (999999, 111111, 111111), + // Powers of 2 + (64, 128, 64), + (1024, 2048, 1024), + ]; + + #[test] + fn test_gcd_i64() { + let test_cases: Vec<(i64, i64, i64)> = [ + GCD_COMMON_TEST_CASES.into(), + vec![ + // Max value cases + (1, i64::MAX, 1), + (i64::MAX, 1, 1), + (i64::MAX, i64::MAX, i64::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" + ); } } -} -/// Computes greatest common divisor using Binary GCD algorithm. -pub fn compute_gcd(x: i64, y: i64) -> Result { - let a = x.unsigned_abs(); - let b = y.unsigned_abs(); - let r = unsigned_gcd(a, b); - // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or - // gcd(i64::MIN, i64::MIN)), which does not fit into i64. - r.try_into().map_err(|_| { - ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) - }) + #[test] + fn test_gcd_decimal128() { + let test_cases: Vec<(i256, i256, i256)> = [ + GCD_COMMON_TEST_CASES + .iter() + .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) + .collect(), + vec![ + (i256::from(1), i256::MAX, i256::from(1)), + (i256::MAX, i256::from(1), i256::from(1)), + (i256::MAX, i256::MAX, i256::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" + ); + } + } } diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 9398e9f8d6e00..705bcc798b193 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; -use arrow::compute::try_binary; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Int64; -use arrow::datatypes::Int64Type; +use arrow::array::ArrayRef; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Int64Type, +}; -use arrow::error::ArrowError; +use crate::math::common::{lcm_signed, lcm_signed_int}; +use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math_cast}; +use datafusion_common::utils::take_function_args; use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::type_coercion::binary::decimal_coercion; use datafusion_macros::user_doc; -use super::gcd::unsigned_gcd; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", @@ -62,9 +59,8 @@ impl Default for LcmFunc { impl LcmFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -78,12 +74,99 @@ impl ScalarUDFImpl for LcmFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Int64) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg1, arg2] = take_function_args(self.name(), arg_types)?; + + let coerced_type = match (arg1, arg2) { + (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { + decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { + exec_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }) + } + (lhs, rhs) => { + exec_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + } + }?; + Ok(vec![coerced_type.clone(), coerced_type]) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(lcm, vec![])(&args.args) + let left = &args.args[0].to_array(args.number_rows)?; + let right = &args.args[1]; + + let arr: ArrayRef = match (left.data_type(), right.data_type()) { + (lhs @ DataType::Int64, _) => { + calculate_binary_math_cast::( + &left, + right, + lcm_signed_int, + lhs, + )? + } + ( + lhs @ DataType::Decimal32(precision, scale), + rhs @ DataType::Decimal32(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal32Type, + Decimal32Type, + Decimal32Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal64(precision, scale), + rhs @ DataType::Decimal64(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal64Type, + Decimal64Type, + Decimal64Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal128(precision, scale), + rhs @ DataType::Decimal128(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal256(precision, scale), + rhs @ DataType::Decimal256(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal256Type, + Decimal256Type, + Decimal256Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + (lhs, rhs) => { + return exec_err!( + "Unsupported data types {lhs:?} and {rhs:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) } fn documentation(&self) -> Option<&Documentation> { @@ -91,36 +174,164 @@ impl ScalarUDFImpl for LcmFunc { } } -/// Lcm SQL function -fn lcm(args: &[ArrayRef]) -> Result { - let compute_lcm = |x: i64, y: i64| -> Result { - if x == 0 || y == 0 { - return Ok(0); +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Decimal128Array, Int64Array}; + use arrow::datatypes::{DECIMAL128_MAX_PRECISION, Field}; + use datafusion_common::ScalarValue; + use datafusion_common::cast::{as_decimal128_array, as_int64_array}; + use datafusion_common::config::ConfigOptions; + use std::sync::Arc; + + #[test] + fn test_i64_array() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("b", DataType::Int64, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![ + 0, 2, 0, 2, 15, 20, + ]))), + ColumnarValue::Array(Arc::new(Int64Array::from(vec![ + 0, 0, 2, 3, 10, 1000, + ]))), + ], + arg_fields, + number_rows: 6, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LcmFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function"); + + match result { + ColumnarValue::Array(arr) => { + let values = + as_int64_array(&arr).expect("failed to convert result to an array"); + assert_eq!(values.len(), 6); + assert_eq!(values.value(0), 0); + assert_eq!(values.value(1), 0); + assert_eq!(values.value(2), 0); + assert_eq!(values.value(3), 6); + assert_eq!(values.value(4), 30); + assert_eq!(values.value(5), 1000); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } } + } + + #[test] + fn test_decimal_scalar() { + let arg_fields = vec![ + Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(2)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(3)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ], + arg_fields, + number_rows: 1, + return_field: Field::new( + "f", + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LcmFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let ints = as_decimal128_array(&arr) + .expect("failed to convert result to an array"); + + assert_eq!(ints.len(), 1); + assert_eq!(ints.value(0), i128::from(6)); + // Signature stays the same as input + assert_eq!( + *arr.data_type(), + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_decimal_array_scalar() { + let arg_fields = vec![ + Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) + .into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![2, 15]) + .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0) + .unwrap(), + )), + ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(i128::from(3)), + DECIMAL128_MAX_PRECISION, + 0, + )), + ], + arg_fields, + number_rows: 2, + return_field: Field::new( + "f", + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LcmFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let ints = as_decimal128_array(&arr) + .expect("failed to convert result to an array"); - // lcm(x, y) = |x| * |y| / gcd(|x|, |y|) - let a = x.unsigned_abs(); - let b = y.unsigned_abs(); - let gcd = unsigned_gcd(a, b); - // gcd is not zero since both a and b are not zero, so the division is safe. - (a / gcd) - .checked_mul(b) - .and_then(|v| i64::try_from(v).ok()) - .ok_or_else(|| { - ArrowError::ComputeError(format!( - "Signed integer overflow in LCM({x}, {y})" - )) - }) - }; - - match args[0].data_type() { - Int64 => { - let arg1 = args[0].as_primitive::(); - let arg2 = args[1].as_primitive::(); - - let result: PrimitiveArray = try_binary(arg1, arg2, compute_lcm)?; - Ok(Arc::new(result) as ArrayRef) + assert_eq!(ints.len(), 2); + assert_eq!(ints.value(0), i128::from(6)); + assert_eq!(ints.value(1), i128::from(15)); + // Signature stays the same as input + assert_eq!( + *arr.data_type(), + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } } - other => exec_err!("Unsupported data type {other:?} for function lcm"), } } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1754ccb43488a..a5d45380ecf0a 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -25,6 +25,7 @@ use std::sync::Arc; pub mod abs; pub mod bounds; pub mod ceil; +mod common; pub mod cot; mod decimal; pub mod factorial; diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index b9bde1454994c..dc0da6b487611 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -123,6 +123,7 @@ where } /// Computes a binary math function for input arrays using a specified function. +/// Deprecated, use [`calculate_binary_math_cast`] instead. /// Generic types: /// - `L`: Left array primitive type /// - `R`: Right array primitive type @@ -133,6 +134,69 @@ pub fn calculate_binary_math( right: &ColumnarValue, fun: F, ) -> Result>> +where + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native) -> Result, + R::Native: TryFrom, +{ + calculate_binary_math_cast::(left, right, fun, &R::DATA_TYPE) +} + +/// Computes a binary math function for input arrays using a specified function +/// and applies rescaling to given precision and scale. +/// Deprecated, use [`calculate_binary_decimal_math_cast`] instead. +/// Generic types: +/// - `L`: Left array decimal type +/// - `R`: Right array primitive type +/// - `O`: Output array decimal type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result` +pub fn calculate_binary_decimal_math( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + precision: u8, + scale: i8, +) -> Result>> +where + L: DecimalType, + R: ArrowPrimitiveType, + O: DecimalType, + F: Fn(L::Native, R::Native) -> Result, + R::Native: TryFrom, +{ + calculate_binary_decimal_math_cast::( + left, + right, + fun, + precision, + scale, + &R::DATA_TYPE, + ) +} + +/// Computes a binary math function for input arrays using a specified function. +/// +/// It casts the right operand to `cast_target` instead of the default `R::DATA_TYPE` to preserve +/// the right operand scale. +/// +/// # Type Parameters +/// - `L`: Left array primitive type +/// - `R`: Right array primitive type +/// - `O`: Output array primitive type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result` +/// # Arguments +/// - `left`: Left input array +/// - `right`: Right input array or scalar value +/// - `fun`: Function of type `F` +/// - `cast_target`: Data type to cast right operand to before applying function +pub fn calculate_binary_math_cast( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + cast_target: &DataType, +) -> Result>> where L: ArrowPrimitiveType, R: ArrowPrimitiveType, @@ -141,7 +205,7 @@ where R::Native: TryFrom, { let left = left.as_primitive::(); - let right = right.cast_to(&R::DATA_TYPE, None)?; + let right = right.cast_to(cast_target, None)?; let result = match right { ColumnarValue::Scalar(scalar) => { if scalar.is_null() { @@ -152,8 +216,7 @@ where let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( "Cannot convert scalar value {} to {}", - &scalar, - R::DATA_TYPE + &scalar, cast_target )) })?; left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))? @@ -168,18 +231,30 @@ where } /// Computes a binary math function for input arrays using a specified function -/// and apply rescaling to given precision and scale. -/// Generic types: +/// and applies rescaling to given precision and scale. +/// +/// It casts the right operand to `cast_target` instead of the default `R::DATA_TYPE` to preserve +/// the right operand scale. +/// +/// # Type Parameters /// - `L`: Left array decimal type /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` -pub fn calculate_binary_decimal_math( +/// # Arguments +/// - `left`: Left input array +/// - `right`: Right input array or scalar value +/// - `fun`: Function of type `F` +/// - `precision`: Precision to apply to output decimal array +/// - `scale`: Scale to apply to output decimal array +/// - `cast_target`: Data type to cast right operand to before applying function +pub fn calculate_binary_decimal_math_cast( left: &dyn Array, right: &ColumnarValue, fun: F, precision: u8, scale: i8, + cast_target: &DataType, ) -> Result>> where L: DecimalType, @@ -188,7 +263,8 @@ where F: Fn(L::Native, R::Native) -> Result, R::Native: TryFrom, { - let result_array = calculate_binary_math::(left, right, fun)?; + let result_array = + calculate_binary_math_cast::(left, right, fun, cast_target)?; Ok(Arc::new( result_array .as_ref() diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index e261bada87eda..39de30d0c8641 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -686,6 +686,22 @@ select gcd(-9223372036854775808, 0); query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(0, \-9223372036854775808\) select gcd(0, -9223372036854775808); +# gcd decimal +query RT +select gcd(2::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(gcd(2::decimal(38, 0), 3::decimal(38, 0))); +---- +1 Decimal128(38, 0) + +query RT +select gcd(0::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(gcd(0::decimal(38, 0), 3::decimal(38, 0))); +---- +3 Decimal128(38, 0) + +query RT +select gcd(2, 3::decimal(38, 0)), arrow_typeof(gcd(2, 3::decimal(38, 0))); +---- +1 Decimal128(38, 0) + ## lcm @@ -727,6 +743,22 @@ select lcm(1, -9223372036854775808); query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(2, 9223372036854775803\) select lcm(2, 9223372036854775803); +# lcm decimal +query R +select lcm(2::decimal(38, 0), 3::decimal(38, 0)); +---- +6 + +query RT +select lcm(0::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(lcm(0::decimal(38, 0), 3::decimal(38, 0))); +---- +0 Decimal128(38, 0) + +query RT +select lcm(2, 3::decimal(38, 0)), arrow_typeof(lcm(2, 3::decimal(38, 0))); +---- +6 Decimal128(38, 0) + ## pow/power @@ -899,6 +931,13 @@ SELECT lcm(6, column1) FROM (VALUES (4), (9), (0)); 18 0 +query R +SELECT lcm(6, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (4), (9), (0)); +---- +12 +18 +0 + # lcm array and scalar with nulls in the array query I SELECT lcm(column1, 5) FROM (VALUES (0), (NULL), (25)); @@ -942,6 +981,13 @@ SELECT gcd(15, column1) FROM (VALUES (10), (25), (0)); 5 15 +query R +SELECT gcd(15, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (10), (25), (0)); +---- +5 +5 +15 + # gcd array and scalar with nulls in the array query I SELECT gcd(column1, 12) FROM (VALUES (8), (NULL), (0), (-36)); From e7fa59832fb990e07ca5ad399b225a5579fdf1bc Mon Sep 17 00:00:00 2001 From: theirix Date: Thu, 11 Jun 2026 13:04:23 +0100 Subject: [PATCH 02/11] Use plan_err --- datafusion/functions/src/math/gcd.rs | 8 +++++--- datafusion/functions/src/math/lcm.rs | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index b9490dd0dc845..327be727326a7 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -25,7 +25,9 @@ use std::sync::Arc; use crate::math::common::{gcd_signed, gcd_signed_int, unsigned_gcd}; use crate::utils::calculate_binary_decimal_math_cast; use datafusion_common::utils::take_function_args; -use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_datafusion_err, plan_err, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -88,14 +90,14 @@ impl ScalarUDFImpl for GcdFunc { (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { - exec_err!( + plan_err!( "Unsupported argument types {lhs:?} and {rhs:?} for function {}", self.name() ) }) } (lhs, rhs) => { - exec_err!( + plan_err!( "Unsupported argument types {lhs:?} and {rhs:?} for function {}", self.name() ) diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 705bcc798b193..166a5afcb3e1e 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -23,7 +23,7 @@ use arrow::datatypes::{ use crate::math::common::{lcm_signed, lcm_signed_int}; use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math_cast}; use datafusion_common::utils::take_function_args; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, exec_err, plan_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -86,14 +86,14 @@ impl ScalarUDFImpl for LcmFunc { (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { - exec_err!( + plan_err!( "Unsupported argument types {lhs:?} and {rhs:?} for function {}", self.name() ) }) } (lhs, rhs) => { - exec_err!( + plan_err!( "Unsupported argument types {lhs:?} and {rhs:?} for function {}", self.name() ) From 1c56990dad5aa9fb5330a5f8c00c277f991fb4b3 Mon Sep 17 00:00:00 2001 From: theirix Date: Thu, 11 Jun 2026 13:04:36 +0100 Subject: [PATCH 03/11] Remove duplicated unit tests --- datafusion/functions/src/math/gcd.rs | 75 ---------------------------- 1 file changed, 75 deletions(-) diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 327be727326a7..491283cce8207 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -421,79 +421,4 @@ mod tests { assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]); } - - const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ - // Basic cases - (48, 18, 6), - (54, 24, 6), - (100, 50, 50), - (17, 19, 1), - (21, 14, 7), - // Edge cases with 0 - (0, 0, 0), - (0, 5, 5), - (10, 0, 10), - // Same numbers - (7, 7, 7), - (100, 100, 100), - // One is 1 - (1, 1, 1), - (1, 100, 1), - (999, 1, 1), - // Large numbers - (1000000, 500000, 500000), - (123456, 789012, 12), - (999999, 111111, 111111), - // Powers of 2 - (64, 128, 64), - (1024, 2048, 1024), - ]; - - #[test] - fn test_gcd_i64() { - let test_cases: Vec<(i64, i64, i64)> = [ - GCD_COMMON_TEST_CASES.into(), - vec![ - // Max value cases - (1, i64::MAX, 1), - (i64::MAX, 1, 1), - (i64::MAX, i64::MAX, i64::MAX), - ], - ] - .concat(); - - // Success cases - for (a, b, expected) in test_cases { - let actual = gcd_signed(a, b).expect("should succeed"); - assert_eq!( - actual, expected, - "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" - ); - } - } - - #[test] - fn test_gcd_decimal128() { - let test_cases: Vec<(i256, i256, i256)> = [ - GCD_COMMON_TEST_CASES - .iter() - .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) - .collect(), - vec![ - (i256::from(1), i256::MAX, i256::from(1)), - (i256::MAX, i256::from(1), i256::from(1)), - (i256::MAX, i256::MAX, i256::MAX), - ], - ] - .concat(); - - // Success cases - for (a, b, expected) in test_cases { - let actual = gcd_signed(a, b).expect("should succeed"); - assert_eq!( - actual, expected, - "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" - ); - } - } } From e3cdd83c360f874d46f5064251a4cc23fe96d41a Mon Sep 17 00:00:00 2001 From: theirix Date: Thu, 11 Jun 2026 13:30:31 +0100 Subject: [PATCH 04/11] Replace gcd unit tests with SLT --- datafusion/functions/src/math/gcd.rs | 159 ------------------- datafusion/functions/src/math/lcm.rs | 162 -------------------- datafusion/sqllogictest/test_files/math.slt | 31 ++++ 3 files changed, 31 insertions(+), 321 deletions(-) diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 491283cce8207..aeddc3f27c409 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -239,165 +239,6 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result { - let values = - as_int64_array(&arr).expect("failed to convert result to an array"); - assert_eq!(values.len(), 6); - assert_eq!(values.value(0), 0); - assert_eq!(values.value(1), 2); - assert_eq!(values.value(2), 2); - assert_eq!(values.value(3), 1); - assert_eq!(values.value(4), 5); - assert_eq!(values.value(5), 20); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_decimal_scalar() { - let arg_fields = vec![ - Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(2)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(3)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ], - arg_fields, - number_rows: 1, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = GcdFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 1); - assert_eq!(ints.value(0), i128::from(1)); - // Signature stays the same as input - assert_eq!( - *arr.data_type(), - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) - ); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_decimal_array_scalar() { - let arg_fields = vec![ - Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![2, 15]) - .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0) - .unwrap(), - )), - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(3)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ], - arg_fields, - number_rows: 2, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = GcdFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 2); - assert_eq!(ints.value(0), i128::from(1)); - assert_eq!(ints.value(1), i128::from(3)); - // Signature stays the same as input - assert_eq!( - *arr.data_type(), - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) - ); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } #[test] fn test_coercion() { diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 166a5afcb3e1e..3861c43b846c4 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -173,165 +173,3 @@ impl ScalarUDFImpl for LcmFunc { self.doc() } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::{Array, Decimal128Array, Int64Array}; - use arrow::datatypes::{DECIMAL128_MAX_PRECISION, Field}; - use datafusion_common::ScalarValue; - use datafusion_common::cast::{as_decimal128_array, as_int64_array}; - use datafusion_common::config::ConfigOptions; - use std::sync::Arc; - - #[test] - fn test_i64_array() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("b", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![ - 0, 2, 0, 2, 15, 20, - ]))), - ColumnarValue::Array(Arc::new(Int64Array::from(vec![ - 0, 0, 2, 3, 10, 1000, - ]))), - ], - arg_fields, - number_rows: 6, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = LcmFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function"); - - match result { - ColumnarValue::Array(arr) => { - let values = - as_int64_array(&arr).expect("failed to convert result to an array"); - assert_eq!(values.len(), 6); - assert_eq!(values.value(0), 0); - assert_eq!(values.value(1), 0); - assert_eq!(values.value(2), 0); - assert_eq!(values.value(3), 6); - assert_eq!(values.value(4), 30); - assert_eq!(values.value(5), 1000); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_decimal_scalar() { - let arg_fields = vec![ - Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(2)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(3)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ], - arg_fields, - number_rows: 1, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = LcmFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 1); - assert_eq!(ints.value(0), i128::from(6)); - // Signature stays the same as input - assert_eq!( - *arr.data_type(), - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) - ); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_decimal_array_scalar() { - let arg_fields = vec![ - Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), true) - .into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![2, 15]) - .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0) - .unwrap(), - )), - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(3)), - DECIMAL128_MAX_PRECISION, - 0, - )), - ], - arg_fields, - number_rows: 2, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = LcmFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 2); - assert_eq!(ints.value(0), i128::from(6)); - assert_eq!(ints.value(1), i128::from(15)); - // Signature stays the same as input - assert_eq!( - *arr.data_type(), - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0) - ); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } -} diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 39de30d0c8641..e1f69256af364 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -931,6 +931,14 @@ SELECT lcm(6, column1) FROM (VALUES (4), (9), (0)); 18 0 +query I +SELECT lcm(column1, column2) FROM (VALUES (0, 5), (3, 5), (25, 5), (-16, 5)); +---- +0 +15 +25 +80 + query R SELECT lcm(6, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (4), (9), (0)); ---- @@ -938,6 +946,13 @@ SELECT lcm(6, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (4), (9), (0 18 0 +query R +SELECT lcm(arrow_cast(column1, 'Decimal128(38,0)'), arrow_cast(column2, 'Decimal128(38,0)')) FROM (VALUES (6, 4), (6, 9), (6, 0)); +---- +12 +18 +0 + # lcm array and scalar with nulls in the array query I SELECT lcm(column1, 5) FROM (VALUES (0), (NULL), (25)); @@ -981,6 +996,14 @@ SELECT gcd(15, column1) FROM (VALUES (10), (25), (0)); 5 15 +query I +SELECT gcd(column1, column2) FROM (VALUES (8, 12), (18, 12), (0, 12), (-36, 12)); +---- +4 +6 +12 +12 + query R SELECT gcd(15, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (10), (25), (0)); ---- @@ -988,6 +1011,14 @@ SELECT gcd(15, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (10), (25), 5 15 +query R +SELECT gcd(arrow_cast(column1, 'Decimal128(38,0)'), arrow_cast(column2, 'Decimal128(38,0)')) FROM (VALUES (15, 10), (15, 25), (15, 0)); +---- +5 +5 +15 + + # gcd array and scalar with nulls in the array query I SELECT gcd(column1, 12) FROM (VALUES (8), (NULL), (0), (-36)); From ec41c47f2f197dfdb0c92efb1b412732d9387cf6 Mon Sep 17 00:00:00 2001 From: theirix Date: Thu, 11 Jun 2026 14:47:41 +0100 Subject: [PATCH 05/11] Tests for decimal and non-integer inputs --- datafusion/sqllogictest/test_files/math.slt | 22 +++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index e1f69256af364..13d458161c12e 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -702,6 +702,22 @@ select gcd(2, 3::decimal(38, 0)), arrow_typeof(gcd(2, 3::decimal(38, 0))); ---- 1 Decimal128(38, 0) +query RR +select gcd(-15::decimal(38, 0), -3::decimal(38, 0)), gcd(-15::decimal(38, 0), 3::decimal(38, 0)); +---- +3 3 + +# floats are coerced to decimals with rounding +query RT +select gcd(15.3::decimal(38, 0), 2.9::decimal(38, 0)), arrow_typeof(gcd(15.3::decimal(38, 0), 2.9::decimal(38, 0))); +---- +3 Decimal128(38, 0) + +# both decimal arguments are coerced to widest - decimal(38, 5), return type is that as well +query RT +select gcd(15::decimal(30, 2), 3::decimal(38, 5)), arrow_typeof(gcd(15::decimal(30, 2), 3::decimal(38, 5))); +---- +3 Decimal128(38, 5) ## lcm @@ -759,6 +775,12 @@ select lcm(2, 3::decimal(38, 0)), arrow_typeof(lcm(2, 3::decimal(38, 0))); ---- 6 Decimal128(38, 0) +# both decimal arguments are coerced to widest - decimal(38, 5), return type is that as well +query RT +select lcm(2::decimal(30, 2), 3::decimal(38, 5)), arrow_typeof(lcm(2::decimal(30, 2), 3::decimal(38, 5))); +---- +6 Decimal128(38, 5) + ## pow/power From 4b739015c36ed23e722214b3be483af52b964fe8 Mon Sep 17 00:00:00 2001 From: theirix Date: Fri, 12 Jun 2026 11:57:33 +0100 Subject: [PATCH 06/11] Change deprecations --- datafusion/functions/src/math/lcm.rs | 16 +++++++--------- datafusion/functions/src/utils.rs | 6 ++++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 3861c43b846c4..245dba0ba3938 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -21,7 +21,7 @@ use arrow::datatypes::{ }; use crate::math::common::{lcm_signed, lcm_signed_int}; -use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math_cast}; +use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math}; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, exec_err, plan_err}; use datafusion_expr::{ @@ -107,14 +107,12 @@ impl ScalarUDFImpl for LcmFunc { let right = &args.args[1]; let arr: ArrayRef = match (left.data_type(), right.data_type()) { - (lhs @ DataType::Int64, _) => { - calculate_binary_math_cast::( - &left, - right, - lcm_signed_int, - lhs, - )? - } + (DataType::Int64, _) => calculate_binary_math::< + Int64Type, + Int64Type, + Int64Type, + _, + >(&left, right, lcm_signed_int)?, ( lhs @ DataType::Decimal32(precision, scale), rhs @ DataType::Decimal32(_, _), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index dc0da6b487611..a30bee09c6b87 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -123,7 +123,6 @@ where } /// Computes a binary math function for input arrays using a specified function. -/// Deprecated, use [`calculate_binary_math_cast`] instead. /// Generic types: /// - `L`: Left array primitive type /// - `R`: Right array primitive type @@ -146,12 +145,15 @@ where /// Computes a binary math function for input arrays using a specified function /// and applies rescaling to given precision and scale. -/// Deprecated, use [`calculate_binary_decimal_math_cast`] instead. /// Generic types: /// - `L`: Left array decimal type /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` +#[deprecated( + since = "55.0.0", + note = "Use `calculate_binary_decimal_math_cast` instead" +)] pub fn calculate_binary_decimal_math( left: &dyn Array, right: &ColumnarValue, From e998de8046ae69cebdaaa59ddda4d3a80c1d28eb Mon Sep 17 00:00:00 2001 From: theirix Date: Fri, 12 Jun 2026 12:33:26 +0100 Subject: [PATCH 07/11] Avoid hard-deprecation until internal usage is migrated --- datafusion/functions/src/utils.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index a30bee09c6b87..bfc7fe10bc36f 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -150,10 +150,7 @@ where /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` -#[deprecated( - since = "55.0.0", - note = "Use `calculate_binary_decimal_math_cast` instead" -)] +/// Deprecated. Use `calculate_binary_decimal_math_cast` instead pub fn calculate_binary_decimal_math( left: &dyn Array, right: &ColumnarValue, From a65d1b4badfbe2dd57f0da7c6e29608735533f59 Mon Sep 17 00:00:00 2001 From: theirix Date: Fri, 12 Jun 2026 12:33:38 +0100 Subject: [PATCH 08/11] Made calculate_binary_math_cast private --- datafusion/functions/src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index bfc7fe10bc36f..d99a28c9f1b0f 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -190,7 +190,7 @@ where /// - `right`: Right input array or scalar value /// - `fun`: Function of type `F` /// - `cast_target`: Data type to cast right operand to before applying function -pub fn calculate_binary_math_cast( +fn calculate_binary_math_cast( left: &dyn Array, right: &ColumnarValue, fun: F, From 323721f89b5eac7cf0e5d1d4288038846da3b47e Mon Sep 17 00:00:00 2001 From: theirix Date: Fri, 12 Jun 2026 13:14:36 +0100 Subject: [PATCH 09/11] doc: fixup comment placement --- datafusion/functions/src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d99a28c9f1b0f..56ffa7eff8242 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -145,12 +145,12 @@ where /// Computes a binary math function for input arrays using a specified function /// and applies rescaling to given precision and scale. +/// Deprecated. Use `calculate_binary_decimal_math_cast` instead /// Generic types: /// - `L`: Left array decimal type /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` -/// Deprecated. Use `calculate_binary_decimal_math_cast` instead pub fn calculate_binary_decimal_math( left: &dyn Array, right: &ColumnarValue, From a26eb47d35af2c631344704876052cf999f57dfd Mon Sep 17 00:00:00 2001 From: theirix Date: Sat, 13 Jun 2026 13:52:50 +0100 Subject: [PATCH 10/11] Deprecate calculate_binary_decimal_math and port usages in round UDF --- datafusion/functions/src/math/round.rs | 14 +++++++++----- datafusion/functions/src/utils.rs | 5 ++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 78016c0f52f71..aacc8820a8cb6 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; +use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math}; use arrow::array::ArrayRef; use arrow::datatypes::DataType::{ @@ -486,7 +486,7 @@ fn round_columnar( } (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => { // reduce scale to reclaim integer precision - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal32Type, Int32Type, Decimal32Type, @@ -518,11 +518,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal64Type, Int32Type, Decimal64Type, @@ -551,11 +552,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal128Type, Int32Type, Decimal128Type, @@ -584,11 +586,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal256Type, Int32Type, Decimal256Type, @@ -617,6 +620,7 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 56ffa7eff8242..39683e9a6afa2 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -145,12 +145,15 @@ where /// Computes a binary math function for input arrays using a specified function /// and applies rescaling to given precision and scale. -/// Deprecated. Use `calculate_binary_decimal_math_cast` instead /// Generic types: /// - `L`: Left array decimal type /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` +#[deprecated( + since = "55.0.0", + note = "Use `calculate_binary_decimal_math_cast` instead" +)] pub fn calculate_binary_decimal_math( left: &dyn Array, right: &ColumnarValue, From 07401f74bfdda4d8dc501be9869c278556335fb7 Mon Sep 17 00:00:00 2001 From: theirix Date: Sun, 14 Jun 2026 07:40:18 +0100 Subject: [PATCH 11/11] Update non-whole number case with scales to match Postgres --- datafusion/sqllogictest/test_files/math.slt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 4e53a37f61de7..583d6f6777865 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -707,11 +707,11 @@ select gcd(-15::decimal(38, 0), -3::decimal(38, 0)), gcd(-15::decimal(38, 0), 3: ---- 3 3 -# floats are coerced to decimals with rounding +# non-whole number case query RT -select gcd(15.3::decimal(38, 0), 2.9::decimal(38, 0)), arrow_typeof(gcd(15.3::decimal(38, 0), 2.9::decimal(38, 0))); +select gcd(15.3::decimal(38, 1), 2.9::decimal(38, 1)), arrow_typeof(gcd(15.3::decimal(38, 1), 2.9::decimal(38, 1))); ---- -3 Decimal128(38, 0) +0.1 Decimal128(38, 1) # both decimal arguments are coerced to widest - decimal(38, 5), return type is that as well query RT