diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 46c4d35227ef8..a940cf272f6a4 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; @@ -31,6 +32,7 @@ use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; +use std::cmp::Ordering; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -400,16 +402,36 @@ fn try_cast_literal_to_type( DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { - ScalarValue::TimestampSecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ); + ScalarValue::TimestampSecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { - ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { - ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) @@ -428,6 +450,32 @@ fn try_cast_literal_to_type( } } +/// Cast a timestamp value from one unit to another +fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { + let value = value as i64; + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + #[cfg(test)] mod tests { use super::*; @@ -1070,4 +1118,162 @@ mod tests { } } } + + #[test] + fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + + // TimestampNanosecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); + } }