diff --git a/datafusion/functions/src/math/common.rs b/datafusion/functions/src/math/common.rs new file mode 100644 index 0000000000000..9bb6f6fe1e35c --- /dev/null +++ b/datafusion/functions/src/math/common.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrowNativeTypeOp; +use arrow::error::ArrowError; +use num_traits::{CheckedMul, CheckedNeg, Signed}; +use std::fmt::Display; +use std::mem::swap; +use std::ops::RemAssign; + +/// A gcd helper to compute GCD using Euclidean GCD algorithm +/// on non-negative numbers (scalars and decimals) +fn gcd_helper(a: T, b: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + CheckedNeg, +{ + debug_assert!(a >= T::ZERO); + debug_assert!(b >= T::ZERO); + let (mut a, mut b) = if a > b { (a, b) } else { (b, a) }; + + while b != T::ZERO { + swap(&mut a, &mut b); + b %= a; + } + + Ok(a) +} + +/// Computes gcd of two unsigned integers using Binary GCD algorithm +/// Faster, works with integers only +pub(crate) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { + if a == 0 { + return b; + } + if b == 0 { + return a; + } + + let shift = (a | b).trailing_zeros(); + a >>= a.trailing_zeros(); + loop { + b >>= b.trailing_zeros(); + if a > b { + swap(&mut a, &mut b); + } + b -= a; + if b == 0 { + return a << shift; + } + } +} + +/// Computes gcd of two signed numbers (integers or decimals), +/// checking for output integer overflow +pub(crate) fn gcd_signed(x: T, y: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg, +{ + // Make absolute values, keeping type + let a = if x.is_positive() { + x + } else { + x.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + let b = if y.is_positive() { + y + } else { + y.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + // Call with signed numbers + gcd_helper(a, b) +} + +/// Computes gcd of two signed integers +pub(crate) fn gcd_signed_int(x: i64, y: i64) -> Result { + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + + // Call with unsigned numbers + let r = unsigned_gcd(a, b); + // gcd(i64::MIN, i64::MIN) = u64::MIN.unsigned_abs() cannot fit into i64 + r.try_into().map_err(|_| { + ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) + }) +} + +/// Computes lcm of two signed numbers (integers or decimals) +pub(crate) fn lcm_signed(x: T, y: T) -> Result +where + T: ArrowNativeTypeOp + RemAssign + Signed + CheckedNeg + CheckedMul + Display, +{ + if x == T::ZERO || y == T::ZERO { + return Ok(T::ZERO); + } + + // Make absolute values, keeping type + let a = if x.is_positive() { + x + } else { + x.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + let b = if y.is_positive() { + y + } else { + y.checked_neg() + .ok_or_else(|| ArrowError::ComputeError("Signed integer overflow".into()))? + }; + // Call with signed numbers + let gcd = gcd_helper(a, b)?; + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd).checked_mul(&b).ok_or_else(|| { + ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})")) + }) +} + +/// Computes lcm of two signed integers, +/// checking for output integer overflow +pub(crate) fn lcm_signed_int(x: i64, y: i64) -> Result { + if x == 0 || y == 0 { + return Ok(0); + } + + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + + let gcd = gcd_helper::(a, b)?; + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd) + .checked_mul(b) + .and_then(|v| i64::try_from(v).ok()) + .ok_or_else(|| { + ArrowError::ComputeError(format!("Signed integer overflow in LCM({x}, {y})")) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::i256; + + const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ + // Basic cases + (48, 18, 6), + (54, 24, 6), + (100, 50, 50), + (17, 19, 1), + (21, 14, 7), + // Edge cases with 0 + (0, 0, 0), + (0, 5, 5), + (10, 0, 10), + // Same numbers + (7, 7, 7), + (100, 100, 100), + // One is 1 + (1, 1, 1), + (1, 100, 1), + (999, 1, 1), + // Large numbers + (1000000, 500000, 500000), + (123456, 789012, 12), + (999999, 111111, 111111), + // Powers of 2 + (64, 128, 64), + (1024, 2048, 1024), + ]; + + const LCM_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [ + // Basic cases + (48, 18, 144), + (54, 24, 216), + (100, 50, 100), + (17, 19, 323), + (21, 14, 42), + // Edge cases with 0 + (0, 0, 0), + (0, 5, 0), + (10, 0, 0), + // Same numbers + (7, 7, 7), + (100, 100, 100), + // One is 1 + (1, 1, 1), + (1, 100, 100), + (999, 1, 999), + // Large numbers + (1_000_000, 500_000, 1_000_000), + (123_456, 789_012, 8_117_355_456), + (999_999, 111_111, 999_999), + // Powers of 2 + (64, 128, 128), + (1024, 2048, 2048), + ]; + + #[test] + fn test_gcd_i64() { + let test_cases: Vec<(i64, i64, i64)> = [ + GCD_COMMON_TEST_CASES.into(), + vec![ + // Max value cases + (1, i64::MAX, 1), + (i64::MAX, 1, 1), + (i64::MAX, i64::MAX, i64::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual_euclidean = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual_euclidean, expected, + "gcd_signed({a}, {b}) expected {expected}, actual {actual_euclidean}" + ); + let actual_binary: i64 = + unsigned_gcd(a.try_into().unwrap(), b.try_into().unwrap()) + .try_into() + .expect("overflow"); + assert_eq!( + actual_binary, expected, + "unsigned_gcd({a}, {b}) expected {expected}, actual {actual_binary}" + ); + } + } + + #[test] + fn test_gcd_decimal() { + let test_cases: Vec<(i256, i256, i256)> = [ + GCD_COMMON_TEST_CASES + .iter() + .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) + .collect(), + vec![ + (i256::from(1), i256::MAX, i256::from(1)), + (i256::MAX, i256::from(1), i256::from(1)), + (i256::MAX, i256::MAX, i256::MAX), + ], + ] + .concat(); + + // Success cases + for (a, b, expected) in test_cases { + let actual = gcd_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "euclid_gcd({a}, {b}) expected {expected}, actual {actual}" + ); + } + } + + #[test] + fn test_lcm_i64() { + let test_cases: Vec<(i64, i64, i64)> = [ + LCM_COMMON_TEST_CASES.into(), + vec![ + // Negative inputs - LCM is always non-negative + (-6, 4, 12), + (-4, -6, 12), + // Max value cases + (1, i64::MAX, i64::MAX), + (i64::MAX, 1, i64::MAX), + (i64::MAX, i64::MAX, i64::MAX), + ], + ] + .concat(); + + for (a, b, expected) in test_cases { + let actual = lcm_signed_int(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "lcm_signed_int({a}, {b}) expected {expected}, actual {actual}" + ); + } + } + + #[test] + fn test_lcm_decimal() { + let test_cases: Vec<(i256, i256, i256)> = [ + LCM_COMMON_TEST_CASES + .iter() + .map(|&(a, b, c)| (i256::from(a), i256::from(b), i256::from(c))) + .collect(), + vec![ + // Negative inputs - LCM is always non-negative + (i256::from(-6_i64), i256::from(4_i64), i256::from(12_i64)), + (i256::from(-4_i64), i256::from(-6_i64), i256::from(12_i64)), + // Max value cases + (i256::from(1_i64), i256::MAX, i256::MAX), + (i256::MAX, i256::from(1_i64), i256::MAX), + (i256::MAX, i256::MAX, i256::MAX), + ], + ] + .concat(); + + for (a, b, expected) in test_cases { + let actual = lcm_signed(a, b).expect("should succeed"); + assert_eq!( + actual, expected, + "lcm_signed({a}, {b}) expected {expected}, actual {actual}" + ); + } + } +} diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 8b92c454d9b4c..aeddc3f27c409 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -17,16 +17,22 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::try_binary; -use arrow::datatypes::{DataType, Int64Type}; -use arrow::error::ArrowError; -use std::mem::swap; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Int64Type, +}; use std::sync::Arc; -use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; +use crate::math::common::{gcd_signed, gcd_signed_int, unsigned_gcd}; +use crate::utils::calculate_binary_decimal_math_cast; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_datafusion_err, plan_err, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::type_coercion::binary::decimal_coercion; use datafusion_macros::user_doc; #[user_doc( @@ -58,11 +64,7 @@ impl Default for GcdFunc { impl GcdFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - vec![DataType::Int64], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -76,37 +78,123 @@ impl ScalarUDFImpl for GcdFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int64) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg1, arg2] = take_function_args(self.name(), arg_types)?; + + let coerced_type = match (arg1, arg2) { + (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { + decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { + plan_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }) + } + (lhs, rhs) => { + plan_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + } + }?; + Ok(vec![coerced_type.clone(), coerced_type]) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let number_rows = args.number_rows; let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { internal_datafusion_err!("Expected 2 arguments for function gcd") })?; - match args { - [ColumnarValue::Array(a), ColumnarValue::Array(b)] => { - compute_gcd_for_arrays(&a, &b) + if args[0].data_type() == DataType::Int64 { + // Optimized path for both integers + match args { + [ColumnarValue::Array(a), ColumnarValue::Array(b)] => { + compute_gcd_for_arrays(&a, &b) + } + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => match (a, b) { + (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + Some(gcd_signed_int(a, b)?), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), + }, + [ + ColumnarValue::Array(a), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => compute_gcd_with_scalar(&a, b), + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Array(b), + ] => compute_gcd_with_scalar(&b, a), + _ => exec_err!("Unsupported argument types for function gcd"), } - [ - ColumnarValue::Scalar(ScalarValue::Int64(a)), - ColumnarValue::Scalar(ScalarValue::Int64(b)), - ] => match (a, b) { - (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - Some(compute_gcd(a, b)?), - ))), - _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), - }, - [ - ColumnarValue::Array(a), - ColumnarValue::Scalar(ScalarValue::Int64(b)), - ] => compute_gcd_with_scalar(&a, b), - [ - ColumnarValue::Scalar(ScalarValue::Int64(a)), - ColumnarValue::Array(b), - ] => compute_gcd_with_scalar(&b, a), - _ => exec_err!("Unsupported argument types for function gcd"), + } else { + // Decimal path: convert left to array and use generic helper + let left = args[0].to_array(number_rows)?; + let right = &args[1]; + + let arr: ArrayRef = match (left.data_type(), right.data_type()) { + ( + lhs @ DataType::Decimal32(precision, scale), + rhs @ DataType::Decimal32(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal32Type, + Decimal32Type, + Decimal32Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal64(precision, scale), + rhs @ DataType::Decimal64(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal64Type, + Decimal64Type, + Decimal64Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal128(precision, scale), + rhs @ DataType::Decimal128(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + ( + lhs @ DataType::Decimal256(precision, scale), + rhs @ DataType::Decimal256(_, _), + ) if *lhs == rhs => calculate_binary_decimal_math_cast::< + Decimal256Type, + Decimal256Type, + Decimal256Type, + _, + >( + &left, right, gcd_signed, *precision, *scale, lhs + )?, + (lhs, rhs) => { + exec_err!( + "Unsupported data types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }?, + }; + Ok(ColumnarValue::Array(arr)) } } @@ -118,7 +206,7 @@ impl ScalarUDFImpl for GcdFunc { fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result { let a = a.as_primitive::(); let b = b.as_primitive::(); - try_binary(a, b, compute_gcd) + try_binary(a, b, gcd_signed_int) .map(|arr: PrimitiveArray| { ColumnarValue::Array(Arc::new(arr) as ArrayRef) }) @@ -141,44 +229,37 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result { let result: PrimitiveArray = - prim.try_unary(|val| compute_gcd(val, scalar_value))?; + prim.try_unary(|val| gcd_signed_int(val, scalar_value))?; Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), } } -/// Computes gcd of two unsigned integers using Binary GCD algorithm. -pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { - if a == 0 { - return b; - } - if b == 0 { - return a; - } +#[cfg(test)] +mod tests { + use super::*; - let shift = (a | b).trailing_zeros(); - a >>= a.trailing_zeros(); - loop { - b >>= b.trailing_zeros(); - if a > b { - swap(&mut a, &mut b); - } - b -= a; - if b == 0 { - return a << shift; - } - } -} + #[test] + fn test_coercion() { + let mut coerced = GcdFunc::new() + .coerce_types(&[DataType::Int64, DataType::Int32]) + .expect("coercion should succeed"); + assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]); -/// Computes greatest common divisor using Binary GCD algorithm. -pub fn compute_gcd(x: i64, y: i64) -> Result { - let a = x.unsigned_abs(); - let b = y.unsigned_abs(); - let r = unsigned_gcd(a, b); - // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or - // gcd(i64::MIN, i64::MIN)), which does not fit into i64. - r.try_into().map_err(|_| { - ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) - }) + coerced = GcdFunc::new() + .coerce_types(&[DataType::Decimal128(10, 2), DataType::Int32]) + .expect("coercion should succeed"); + + assert_eq!( + coerced, + vec![DataType::Decimal128(12, 2), DataType::Decimal128(12, 2)] + ); + + coerced = GcdFunc::new() + .coerce_types(&[DataType::Decimal128(10, 2), DataType::Null]) + .expect("coercion should succeed"); + + assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]); + } } diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 9398e9f8d6e00..245dba0ba3938 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; -use arrow::compute::try_binary; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Int64; -use arrow::datatypes::Int64Type; +use arrow::array::ArrayRef; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Int64Type, +}; -use arrow::error::ArrowError; -use datafusion_common::{Result, exec_err}; +use crate::math::common::{lcm_signed, lcm_signed_int}; +use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err, plan_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::type_coercion::binary::decimal_coercion; use datafusion_macros::user_doc; -use super::gcd::unsigned_gcd; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", @@ -62,9 +59,8 @@ impl Default for LcmFunc { impl LcmFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -78,49 +74,100 @@ impl ScalarUDFImpl for LcmFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Int64) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg1, arg2] = take_function_args(self.name(), arg_types)?; + + let coerced_type = match (arg1, arg2) { + (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_integer() && rhs.is_integer() => Ok(DataType::Int64), + (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => { + decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| { + plan_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + }) + } + (lhs, rhs) => { + plan_err!( + "Unsupported argument types {lhs:?} and {rhs:?} for function {}", + self.name() + ) + } + }?; + Ok(vec![coerced_type.clone(), coerced_type]) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(lcm, vec![])(&args.args) + let left = &args.args[0].to_array(args.number_rows)?; + let right = &args.args[1]; + + let arr: ArrayRef = match (left.data_type(), right.data_type()) { + (DataType::Int64, _) => calculate_binary_math::< + Int64Type, + Int64Type, + Int64Type, + _, + >(&left, right, lcm_signed_int)?, + ( + lhs @ DataType::Decimal32(precision, scale), + rhs @ DataType::Decimal32(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal32Type, + Decimal32Type, + Decimal32Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal64(precision, scale), + rhs @ DataType::Decimal64(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal64Type, + Decimal64Type, + Decimal64Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal128(precision, scale), + rhs @ DataType::Decimal128(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal128Type, + Decimal128Type, + Decimal128Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + ( + lhs @ DataType::Decimal256(precision, scale), + rhs @ DataType::Decimal256(_, _), + ) if *lhs == rhs => { + calculate_binary_decimal_math_cast::< + Decimal256Type, + Decimal256Type, + Decimal256Type, + _, + >(&left, right, lcm_signed, *precision, *scale, lhs)? + } + (lhs, rhs) => { + return exec_err!( + "Unsupported data types {lhs:?} and {rhs:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) } fn documentation(&self) -> Option<&Documentation> { self.doc() } } - -/// Lcm SQL function -fn lcm(args: &[ArrayRef]) -> Result { - let compute_lcm = |x: i64, y: i64| -> Result { - if x == 0 || y == 0 { - return Ok(0); - } - - // lcm(x, y) = |x| * |y| / gcd(|x|, |y|) - let a = x.unsigned_abs(); - let b = y.unsigned_abs(); - let gcd = unsigned_gcd(a, b); - // gcd is not zero since both a and b are not zero, so the division is safe. - (a / gcd) - .checked_mul(b) - .and_then(|v| i64::try_from(v).ok()) - .ok_or_else(|| { - ArrowError::ComputeError(format!( - "Signed integer overflow in LCM({x}, {y})" - )) - }) - }; - - match args[0].data_type() { - Int64 => { - let arg1 = args[0].as_primitive::(); - let arg2 = args[1].as_primitive::(); - - let result: PrimitiveArray = try_binary(arg1, arg2, compute_lcm)?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!("Unsupported data type {other:?} for function lcm"), - } -} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1754ccb43488a..a5d45380ecf0a 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -25,6 +25,7 @@ use std::sync::Arc; pub mod abs; pub mod bounds; pub mod ceil; +mod common; pub mod cot; mod decimal; pub mod factorial; diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 78016c0f52f71..aacc8820a8cb6 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; +use crate::utils::{calculate_binary_decimal_math_cast, calculate_binary_math}; use arrow::array::ArrayRef; use arrow::datatypes::DataType::{ @@ -486,7 +486,7 @@ fn round_columnar( } (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => { // reduce scale to reclaim integer precision - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal32Type, Int32Type, Decimal32Type, @@ -518,11 +518,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal64Type, Int32Type, Decimal64Type, @@ -551,11 +552,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal128Type, Int32Type, Decimal128Type, @@ -584,11 +586,12 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => { - let result = calculate_binary_decimal_math::< + let result = calculate_binary_decimal_math_cast::< Decimal256Type, Int32Type, Decimal256Type, @@ -617,6 +620,7 @@ fn round_columnar( }, *precision, *new_scale, + &DataType::Int32, )?; result as _ } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index b9bde1454994c..39683e9a6afa2 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -133,6 +133,72 @@ pub fn calculate_binary_math( right: &ColumnarValue, fun: F, ) -> Result>> +where + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native) -> Result, + R::Native: TryFrom, +{ + calculate_binary_math_cast::(left, right, fun, &R::DATA_TYPE) +} + +/// Computes a binary math function for input arrays using a specified function +/// and applies rescaling to given precision and scale. +/// Generic types: +/// - `L`: Left array decimal type +/// - `R`: Right array primitive type +/// - `O`: Output array decimal type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result` +#[deprecated( + since = "55.0.0", + note = "Use `calculate_binary_decimal_math_cast` instead" +)] +pub fn calculate_binary_decimal_math( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + precision: u8, + scale: i8, +) -> Result>> +where + L: DecimalType, + R: ArrowPrimitiveType, + O: DecimalType, + F: Fn(L::Native, R::Native) -> Result, + R::Native: TryFrom, +{ + calculate_binary_decimal_math_cast::( + left, + right, + fun, + precision, + scale, + &R::DATA_TYPE, + ) +} + +/// Computes a binary math function for input arrays using a specified function. +/// +/// It casts the right operand to `cast_target` instead of the default `R::DATA_TYPE` to preserve +/// the right operand scale. +/// +/// # Type Parameters +/// - `L`: Left array primitive type +/// - `R`: Right array primitive type +/// - `O`: Output array primitive type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result` +/// # Arguments +/// - `left`: Left input array +/// - `right`: Right input array or scalar value +/// - `fun`: Function of type `F` +/// - `cast_target`: Data type to cast right operand to before applying function +fn calculate_binary_math_cast( + left: &dyn Array, + right: &ColumnarValue, + fun: F, + cast_target: &DataType, +) -> Result>> where L: ArrowPrimitiveType, R: ArrowPrimitiveType, @@ -141,7 +207,7 @@ where R::Native: TryFrom, { let left = left.as_primitive::(); - let right = right.cast_to(&R::DATA_TYPE, None)?; + let right = right.cast_to(cast_target, None)?; let result = match right { ColumnarValue::Scalar(scalar) => { if scalar.is_null() { @@ -152,8 +218,7 @@ where let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( "Cannot convert scalar value {} to {}", - &scalar, - R::DATA_TYPE + &scalar, cast_target )) })?; left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))? @@ -168,18 +233,30 @@ where } /// Computes a binary math function for input arrays using a specified function -/// and apply rescaling to given precision and scale. -/// Generic types: +/// and applies rescaling to given precision and scale. +/// +/// It casts the right operand to `cast_target` instead of the default `R::DATA_TYPE` to preserve +/// the right operand scale. +/// +/// # Type Parameters /// - `L`: Left array decimal type /// - `R`: Right array primitive type /// - `O`: Output array decimal type /// - `F`: Functor computing `fun(l: L, r: R) -> Result` -pub fn calculate_binary_decimal_math( +/// # Arguments +/// - `left`: Left input array +/// - `right`: Right input array or scalar value +/// - `fun`: Function of type `F` +/// - `precision`: Precision to apply to output decimal array +/// - `scale`: Scale to apply to output decimal array +/// - `cast_target`: Data type to cast right operand to before applying function +pub fn calculate_binary_decimal_math_cast( left: &dyn Array, right: &ColumnarValue, fun: F, precision: u8, scale: i8, + cast_target: &DataType, ) -> Result>> where L: DecimalType, @@ -188,7 +265,8 @@ where F: Fn(L::Native, R::Native) -> Result, R::Native: TryFrom, { - let result_array = calculate_binary_math::(left, right, fun)?; + let result_array = + calculate_binary_math_cast::(left, right, fun, cast_target)?; Ok(Arc::new( result_array .as_ref() diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 1748c9b3e5d36..583d6f6777865 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -686,6 +686,38 @@ select gcd(-9223372036854775808, 0); query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(0, \-9223372036854775808\) select gcd(0, -9223372036854775808); +# gcd decimal +query RT +select gcd(2::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(gcd(2::decimal(38, 0), 3::decimal(38, 0))); +---- +1 Decimal128(38, 0) + +query RT +select gcd(0::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(gcd(0::decimal(38, 0), 3::decimal(38, 0))); +---- +3 Decimal128(38, 0) + +query RT +select gcd(2, 3::decimal(38, 0)), arrow_typeof(gcd(2, 3::decimal(38, 0))); +---- +1 Decimal128(38, 0) + +query RR +select gcd(-15::decimal(38, 0), -3::decimal(38, 0)), gcd(-15::decimal(38, 0), 3::decimal(38, 0)); +---- +3 3 + +# non-whole number case +query RT +select gcd(15.3::decimal(38, 1), 2.9::decimal(38, 1)), arrow_typeof(gcd(15.3::decimal(38, 1), 2.9::decimal(38, 1))); +---- +0.1 Decimal128(38, 1) + +# both decimal arguments are coerced to widest - decimal(38, 5), return type is that as well +query RT +select gcd(15::decimal(30, 2), 3::decimal(38, 5)), arrow_typeof(gcd(15::decimal(30, 2), 3::decimal(38, 5))); +---- +3 Decimal128(38, 5) ## lcm @@ -727,6 +759,28 @@ select lcm(1, -9223372036854775808); query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(2, 9223372036854775803\) select lcm(2, 9223372036854775803); +# lcm decimal +query R +select lcm(2::decimal(38, 0), 3::decimal(38, 0)); +---- +6 + +query RT +select lcm(0::decimal(38, 0), 3::decimal(38, 0)), arrow_typeof(lcm(0::decimal(38, 0), 3::decimal(38, 0))); +---- +0 Decimal128(38, 0) + +query RT +select lcm(2, 3::decimal(38, 0)), arrow_typeof(lcm(2, 3::decimal(38, 0))); +---- +6 Decimal128(38, 0) + +# both decimal arguments are coerced to widest - decimal(38, 5), return type is that as well +query RT +select lcm(2::decimal(30, 2), 3::decimal(38, 5)), arrow_typeof(lcm(2::decimal(30, 2), 3::decimal(38, 5))); +---- +6 Decimal128(38, 5) + ## pow/power @@ -899,6 +953,28 @@ SELECT lcm(6, column1) FROM (VALUES (4), (9), (0)); 18 0 +query I +SELECT lcm(column1, column2) FROM (VALUES (0, 5), (3, 5), (25, 5), (-16, 5)); +---- +0 +15 +25 +80 + +query R +SELECT lcm(6, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (4), (9), (0)); +---- +12 +18 +0 + +query R +SELECT lcm(arrow_cast(column1, 'Decimal128(38,0)'), arrow_cast(column2, 'Decimal128(38,0)')) FROM (VALUES (6, 4), (6, 9), (6, 0)); +---- +12 +18 +0 + # lcm array and scalar with nulls in the array query I SELECT lcm(column1, 5) FROM (VALUES (0), (NULL), (25)); @@ -942,6 +1018,29 @@ SELECT gcd(15, column1) FROM (VALUES (10), (25), (0)); 5 15 +query I +SELECT gcd(column1, column2) FROM (VALUES (8, 12), (18, 12), (0, 12), (-36, 12)); +---- +4 +6 +12 +12 + +query R +SELECT gcd(15, arrow_cast(column1, 'Decimal128(38,0)')) FROM (VALUES (10), (25), (0)); +---- +5 +5 +15 + +query R +SELECT gcd(arrow_cast(column1, 'Decimal128(38,0)'), arrow_cast(column2, 'Decimal128(38,0)')) FROM (VALUES (15, 10), (15, 25), (15, 0)); +---- +5 +5 +15 + + # gcd array and scalar with nulls in the array query I SELECT gcd(column1, 12) FROM (VALUES (8), (NULL), (0), (-36));