From 3d4db817d20388cd1286927d6c9b9e8e5c640449 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 8 Mar 2023 22:29:44 +0100 Subject: [PATCH 1/5] fix: cast literal to timestamp --- .../src/unwrap_cast_in_comparison.rs | 72 +++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 46c4d35227ef8..833876d53cc1a 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -400,16 +400,40 @@ fn try_cast_literal_to_type( DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { - ScalarValue::TimestampSecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ) + .unwrap(); + ScalarValue::TimestampSecond(Some(value), tz.clone()) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { - ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ) + .unwrap(); + ScalarValue::TimestampMillisecond(Some(value), tz.clone()) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { - ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ) + .unwrap(); + ScalarValue::TimestampMicrosecond(Some(value), tz.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ) + .unwrap(); + ScalarValue::TimestampNanosecond(Some(value), tz.clone()) } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) @@ -428,6 +452,31 @@ fn try_cast_literal_to_type( } } +/// Cast a timestamp value from one unit to another +fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { + let seconds = match from { + DataType::Timestamp(TimeUnit::Second, _) => Some(value * 1000 * 1000 * 1000), + DataType::Timestamp(TimeUnit::Millisecond, _) => Some(value * 1000 * 1000), + DataType::Timestamp(TimeUnit::Microsecond, _) => Some(value * 1000), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Some(value), + _ => return Some(value as i64), + }; + + match to { + DataType::Timestamp(TimeUnit::Second, _) => { + seconds.map(|s| (s / 1000 / 1000 / 1000) as i64) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + seconds.map(|s| (s / 1000 / 1000) as i64) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + seconds.map(|s| (s / 1000) as i64) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => seconds.map(|s| s as i64), + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1070,4 +1119,19 @@ mod tests { } } } + + #[test] + fn test_try_cast_literal_to_timestamp() { + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + } } From ff9d79512d8ec2397c504aab2d81f1a80b803ef9 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Thu, 9 Mar 2023 10:40:40 +0100 Subject: [PATCH 2/5] update tests for all transformation --- .../src/unwrap_cast_in_comparison.rs | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 833876d53cc1a..583dfee804ea2 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -1122,6 +1122,7 @@ mod tests { #[test] fn test_try_cast_literal_to_timestamp() { + // TimestampNanosecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), &DataType::Timestamp(TimeUnit::Microsecond, None), @@ -1133,5 +1134,125 @@ mod tests { new_scalar, ScalarValue::TimestampMicrosecond(Some(123), None) ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); } } From 4f95a55f24469843ceb95ddb1778ca167aadae56 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Thu, 9 Mar 2023 10:46:31 +0100 Subject: [PATCH 3/5] handle cast between same type --- .../optimizer/src/unwrap_cast_in_comparison.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 583dfee804ea2..f44fb0fdbca18 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -454,6 +454,11 @@ fn try_cast_literal_to_type( /// Cast a timestamp value from one unit to another fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { + // avoid useless computation, like cast from second to second + if from == to { + return Some(value as i64); + } + let seconds = match from { DataType::Timestamp(TimeUnit::Second, _) => Some(value * 1000 * 1000 * 1000), DataType::Timestamp(TimeUnit::Millisecond, _) => Some(value * 1000 * 1000), @@ -1122,6 +1127,19 @@ mod tests { #[test] fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + // TimestampNanosecond to TimestampMicrosecond let new_scalar = try_cast_literal_to_type( &ScalarValue::TimestampNanosecond(Some(123456), None), From 2ca12a35de55e9226517c9556f5982b9c39b57d5 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Thu, 9 Mar 2023 20:01:25 +0100 Subject: [PATCH 4/5] refactor cast_between_timestamp to avoid overflow --- .../src/unwrap_cast_in_comparison.rs | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index f44fb0fdbca18..e1bd1c00ada62 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; @@ -31,6 +32,7 @@ use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; +use std::cmp::Ordering; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -454,31 +456,26 @@ fn try_cast_literal_to_type( /// Cast a timestamp value from one unit to another fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { - // avoid useless computation, like cast from second to second - if from == to { - return Some(value as i64); - } + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value as i64), + }; - let seconds = match from { - DataType::Timestamp(TimeUnit::Second, _) => Some(value * 1000 * 1000 * 1000), - DataType::Timestamp(TimeUnit::Millisecond, _) => Some(value * 1000 * 1000), - DataType::Timestamp(TimeUnit::Microsecond, _) => Some(value * 1000), - DataType::Timestamp(TimeUnit::Nanosecond, _) => Some(value), + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, _ => return Some(value as i64), }; - match to { - DataType::Timestamp(TimeUnit::Second, _) => { - seconds.map(|s| (s / 1000 / 1000 / 1000) as i64) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - seconds.map(|s| (s / 1000 / 1000) as i64) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - seconds.map(|s| (s / 1000) as i64) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => seconds.map(|s| s as i64), - _ => None, + match from_scale.cmp(&to_scale) { + Ordering::Less => Some(value as i64 * (to_scale / from_scale)), + Ordering::Greater => Some(value as i64 / (from_scale / to_scale)), + Ordering::Equal => Some(value as i64), } } From 89ea8162aca3a6763e700c55c06e85abd8da0f47 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 10 Mar 2023 14:16:22 +0100 Subject: [PATCH 5/5] handle overflow to None --- .../src/unwrap_cast_in_comparison.rs | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index e1bd1c00ada62..a940cf272f6a4 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -406,36 +406,32 @@ fn try_cast_literal_to_type( lit_data_type, DataType::Timestamp(TimeUnit::Second, tz.clone()), value, - ) - .unwrap(); - ScalarValue::TimestampSecond(Some(value), tz.clone()) + ); + ScalarValue::TimestampSecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { let value = cast_between_timestamp( lit_data_type, DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), value, - ) - .unwrap(); - ScalarValue::TimestampMillisecond(Some(value), tz.clone()) + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { let value = cast_between_timestamp( lit_data_type, DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), value, - ) - .unwrap(); - ScalarValue::TimestampMicrosecond(Some(value), tz.clone()) + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { let value = cast_between_timestamp( lit_data_type, DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), value, - ) - .unwrap(); - ScalarValue::TimestampNanosecond(Some(value), tz.clone()) + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) @@ -456,12 +452,13 @@ fn try_cast_literal_to_type( /// Cast a timestamp value from one unit to another fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { + let value = value as i64; let from_scale = match from { DataType::Timestamp(TimeUnit::Second, _) => 1, DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value as i64), + _ => return Some(value), }; let to_scale = match to { @@ -469,13 +466,13 @@ fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option MILLISECONDS, DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value as i64), + _ => return Some(value), }; match from_scale.cmp(&to_scale) { - Ordering::Less => Some(value as i64 * (to_scale / from_scale)), - Ordering::Greater => Some(value as i64 / (from_scale / to_scale)), - Ordering::Equal => Some(value as i64), + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), } } @@ -1269,5 +1266,14 @@ mod tests { new_scalar, ScalarValue::TimestampMillisecond(Some(123000), None) ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); } }