diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 198ad88b945bd..6b8eaa0be0b82 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -27,15 +27,15 @@ use arrow::datatypes::{ Decimal256Type, Float64Type, Int64Type, }; use arrow::error::ArrowError; +use datafusion_common::types::{NativeType, logical_float64, logical_int64}; use datafusion_common::utils::take_function_args; -use datafusion_common::{Result, ScalarValue, exec_err, plan_datafusion_err}; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::type_coercion::is_decimal; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit, }; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -67,8 +67,26 @@ impl Default for PowerFunc { impl PowerFunc { pub fn new() -> Self { + let integer = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ); + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let float = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![decimal.clone(), integer]), + TypeSignature::Coercible(vec![decimal.clone(), float.clone()]), + TypeSignature::Coercible(vec![float; 2]), + ], + Volatility::Immutable, + ), aliases: vec![String::from("pow")], } } @@ -153,6 +171,7 @@ impl ScalarUDFImpl for PowerFunc { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "power" } @@ -162,57 +181,23 @@ impl ScalarUDFImpl for PowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + if arg_types[0].is_null() { + Ok(DataType::Float64) + } else { + Ok(arg_types[0].clone()) + } } fn aliases(&self) -> &[String] { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [arg1, arg2] = take_function_args(self.name(), arg_types)?; - - fn coerced_type_exp(name: &str, data_type: &DataType) -> Result { - match data_type { - DataType::Null => Ok(DataType::Int64), - d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(DataType::Float64), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - // Determine the exponent type first, as it affects base coercion - let exp_type = coerced_type_exp(self.name(), arg2)?; - - // For base coercion: always use Float64 for integer/null bases - // This matches PostgreSQL behavior and handles negative exponents correctly - fn coerced_type_base(name: &str, data_type: &DataType) -> Result { - match data_type { - d if d.is_floating() => Ok(DataType::Float64), - // Integer and Null bases always coerce to Float64 - // (integer power doesn't support negative exponents, and pow() - // should return float like PostgreSQL does) - DataType::Null => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Float64), - d if is_decimal(d) => Ok(d.clone()), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - Ok(vec![coerced_type_base(self.name(), arg1)?, exp_type]) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let base = &args.args[0].to_array(args.number_rows)?; - let exponent = &args.args[1]; + let [base, exponent] = take_function_args(self.name(), &args.args)?; + let base = base.to_array(args.number_rows)?; let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { - (DataType::Float64, _) => { + (DataType::Float64, DataType::Float64) => { calculate_binary_math::( &base, exponent, @@ -322,9 +307,8 @@ impl ScalarUDFImpl for PowerFunc { )? } (base_type, exp_type) => { - return exec_err!( - "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for function {}", - self.name() + return internal_err!( + "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power" ); } }; @@ -332,30 +316,33 @@ impl ScalarUDFImpl for PowerFunc { } /// Simplify the `power` function by the relevant rules: - /// 1. Power(a, 0) ===> 0 + /// 1. Power(a, 0) ===> 1 /// 2. Power(a, 1) ===> a /// 3. Power(a, Log(a, b)) ===> b fn simplify( &self, - mut args: Vec, + args: Vec, info: &dyn SimplifyInfo, ) -> Result { - let exponent = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 0") - })?; - let base = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 1") - })?; - + let [base, exponent] = take_function_args("power", args)?; + let base_type = info.get_data_type(&base)?; let exponent_type = info.get_data_type(&exponent)?; + + // Null propagation + if base_type.is_null() || exponent_type.is_null() { + let return_type = self.return_type(&[base_type, exponent_type])?; + return Ok(ExprSimplifyResult::Simplified(lit( + ScalarValue::Null.cast_to(&return_type)? + ))); + } + match exponent { Expr::Literal(value, _) if value == ScalarValue::new_zero(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::new_one(&info.get_data_type(&base)?)?, - None, - ))) + Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( + &base_type, + )?))) } Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) @@ -383,202 +370,6 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array}; - use arrow::datatypes::{DECIMAL128_MAX_SCALE, Field}; - use datafusion_common::cast::{as_decimal128_array, as_float64_array}; - use datafusion_common::config::ConfigOptions; - use std::sync::Arc; - - #[cfg(test)] - #[ctor::ctor] - fn init() { - // Enable RUST_LOG logging configuration for test - let _ = env_logger::try_init(); - } - - #[test] - fn test_power_f64() { - let arg_fields = vec![ - Field::new("a", DataType::Float64, true).into(), - Field::new("a", DataType::Float64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, - ]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 3.0, 2.0, 4.0, 4.0, - ]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Float64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let floats = as_float64_array(&arr) - .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 8.0); - assert_eq!(floats.value(1), 4.0); - assert_eq!(floats.value(2), 81.0); - assert_eq!(floats.value(3), 625.0); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_i128() { - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![2, 2, 3, 5, 0, 5]) - .with_precision_and_scale(DECIMAL128_MAX_SCALE as u8, 0) - .unwrap(), - )), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4, 4, 0]))), // exponent - ], - arg_fields, - number_rows: 6, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 6); - assert_eq!(ints.value(0), i128::from(8)); - assert_eq!(ints.value(1), i128::from(4)); - assert_eq!(ints.value(2), i128::from(81)); - assert_eq!(ints.value(3), i128::from(625)); - assert_eq!(ints.value(4), i128::from(0)); - assert_eq!(ints.value(5), i128::from(1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_array_null() { - let arg_fields = vec![ - Field::new("a", DataType::Float64, true).into(), - Field::new("a", DataType::Float64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 2.0]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - Some(1.0), - None, - Some(3.0), - ]))), // exponent - ], - arg_fields, - number_rows: 3, - return_field: Field::new("f", DataType::Float64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let floats = - as_float64_array(&arr).expect("failed to convert result to an array"); - - assert_eq!(floats.len(), 3); - assert!(!floats.is_null(0)); - assert_eq!(floats.value(0), 2.0); - assert!(floats.is_null(1)); - assert!(!floats.is_null(2)); - assert_eq!(floats.value(2), 8.0); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_decimal_with_scale() { - // 2.5 ^ 4 = 39 - // 2.5 is 25 in Decimal128(2, 1) by parsing rules - // Signature is Decimal128(2, 1) -> Int64 -> Decimal128(2, 1), therefore - // result is 390 in Decimal128(2, 1) aka 39 in unscaled Decimal128(2, 0) - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(25)), - 2, - 1, - )), // base - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), // exponent - ], - arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Decimal128(2, 1), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 1); - assert_eq!(ints.value(0), i128::from(390)); - // Signature stays the same as input - assert_eq!(*arr.data_type(), DataType::Decimal128(2, 1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } #[test] fn test_pow_decimal128_helper() { @@ -601,46 +392,4 @@ mod tests { "Not yet implemented: Negative scale is not yet supported value: -1" ); } - - #[test] - fn test_power_coerce_types() { - let power_func = PowerFunc::new(); - - // Int64 base with Int64 exponent -> base coerced to Float64 (like PostgreSQL) - // This allows negative exponents to work correctly - let result = power_func - .coerce_types(&[DataType::Int64, DataType::Int64]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Int64]); - - // Float64 base with Float64 exponent -> both stay Float64 - let result = power_func - .coerce_types(&[DataType::Float64, DataType::Float64]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Float64]); - - // Int64 base with Float64 exponent -> base coerced to Float64 - let result = power_func - .coerce_types(&[DataType::Int64, DataType::Float64]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Float64]); - - // Int32 base with Float32 exponent -> both coerced to Float64 - let result = power_func - .coerce_types(&[DataType::Int32, DataType::Float32]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Float64]); - - // Null base with Float64 exponent -> base coerced to Float64 - let result = power_func - .coerce_types(&[DataType::Null, DataType::Float64]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Float64]); - - // Null base with Int64 exponent -> base coerced to Float64 (like PostgreSQL) - let result = power_func - .coerce_types(&[DataType::Null, DataType::Int64]) - .unwrap(); - assert_eq!(result, vec![DataType::Float64, DataType::Int64]); - } } diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 322ba7a104a7d..53cf17fe7a545 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -740,6 +740,60 @@ select power(2107754225, 1221660777); ---- Infinity +query R rowsort +select power(base::double, exponent::double) +from values + (2.0, 2.0), + (5.0, 4.0), + (2.0, 3.0), + (3.0, 4.0) as t(base, exponent); +---- +4 +625 +8 +81 + +query R rowsort +select power(base::bigint, exponent::bigint) +from values + (2, 2), + (5, 4), + (2, 3), + (3, 4), + (2, NULL) as t(base, exponent); +---- +4 +625 +8 +81 +NULL + +query RT rowsort +select + power(base::decimal(38, 0), exponent::decimal(38, 0)), + arrow_typeof(power(base::decimal(38, 0), exponent::decimal(38, 0))) +from values + (0, 4), + (5, 0), + (2, 2), + (5, 4), + (2, 3), + (3, 4) as t(base, exponent); +---- +0 Decimal128(38, 0) +1 Decimal128(38, 0) +4 Decimal128(38, 0) +625 Decimal128(38, 0) +8 Decimal128(38, 0) +81 Decimal128(38, 0) + +query RT +select + pow(2.5::decimal(2, 1), 4::bigint), + arrow_typeof(pow(2.5::decimal(2, 1), 4::bigint)); +---- +39 Decimal128(2, 1) + # factorial overflow query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) select FACTORIAL(350943270); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 7c6b38b78e500..9c7071cb65c23 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1883,7 +1883,7 @@ D false # test string_temporal_coercion query BBBBBBBBBB -select +select arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11',