Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 91 additions & 54 deletions datafusion/functions/src/datetime/date_bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ use arrow::datatypes::{
DataType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType, TimeUnit,
};
use arrow::error::ArrowError;
use arrow::temporal_conversions::NANOSECONDS_IN_DAY;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err, plan_err};
use datafusion_common::{
DataFusionError, Result, ScalarValue, exec_err, not_impl_err, plan_err,
};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Expand Down Expand Up @@ -322,7 +325,7 @@ impl Interval {
// return time in nanoseconds that the source timestamp falls into based on the stride and origin
fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Result<i64> {
let time_diff = source.checked_sub(origin).ok_or_else(|| {
arrow::error::ArrowError::InvalidArgumentError(format!(
ArrowError::InvalidArgumentError(format!(
"date_bin source timestamp {source} - origin {origin} overflows i64"
))
})?;
Expand All @@ -331,7 +334,7 @@ fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Resul
let time_delta = compute_distance(time_diff, stride_nanos)?;

origin.checked_add(time_delta).ok_or_else(|| {
arrow::error::ArrowError::InvalidArgumentError(format!(
ArrowError::InvalidArgumentError(format!(
"date_bin origin {origin} + delta {time_delta} overflows i64"
))
.into()
Expand All @@ -341,20 +344,20 @@ fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Resul
// distance from origin to bin
fn compute_distance(time_diff: i64, stride: i64) -> Result<i64> {
let remainder = time_diff.checked_rem(stride).ok_or_else(|| {
arrow::error::ArrowError::InvalidArgumentError(format!(
ArrowError::InvalidArgumentError(format!(
"date_bin compute_distance time_diff {time_diff} % stride {stride} overflows i64"
))
})?;
let time_delta = time_diff.checked_sub(remainder).ok_or_else(|| {
arrow::error::ArrowError::InvalidArgumentError(format!(
ArrowError::InvalidArgumentError(format!(
"date_bin compute_distance time_diff {time_diff} - remainder {remainder} overflows i64"
))
})?;

if time_diff < 0 && stride > 1 && time_delta != time_diff {
// The origin is later than the source timestamp, round down to the previous bin
time_delta.checked_sub(stride).ok_or_else(|| {
arrow::error::ArrowError::InvalidArgumentError(format!(
ArrowError::InvalidArgumentError(format!(
"date_bin compute_distance time_delta {time_delta} - stride {stride} overflows i64"
))
.into()
Expand Down Expand Up @@ -594,53 +597,91 @@ fn date_bin_impl(
return exec_err!("DATE_BIN stride must be non-zero");
}

fn stride_map_fn<T: ArrowTimestampType>(
origin: i64,
stride: i64,
stride_fn: BinFunction,
) -> impl Fn(i64) -> Result<i64> {
let scale = match T::UNIT {
fn timestamp_scale<T: ArrowTimestampType>() -> i64 {
match T::UNIT {
Nanosecond => 1,
Microsecond => NANOS_PER_MICRO,
Millisecond => NANOS_PER_MILLI,
Second => NANOSECONDS,
};
move |x: i64| match stride_fn(stride, x * scale, origin) {
Ok(result) => Ok(result / scale),
Err(e) => Err(e),
}
}

fn timestamp_scale_overflow_error(x: i64) -> DataFusionError {
DataFusionError::Execution(format!(
"DATE_BIN source timestamp {x} cannot be represented in nanoseconds"
))
}

Ok(match array {
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
let apply_stride_fn =
stride_map_fn::<TimestampNanosecondType>(origin, stride, stride_fn);
let scale = timestamp_scale::<TimestampNanosecondType>();
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
v.and_then(|val| apply_stride_fn(val).ok()),
match *v {
Some(val) => {
let scaled = val
.checked_mul(scale)
.ok_or_else(|| timestamp_scale_overflow_error(val))?;
match stride_fn(stride, scaled, origin) {
Ok(result) => Some(result / scale),
Err(_) => None,
}
Comment thread
xiedeyantu marked this conversation as resolved.
}
None => None,
},
tz_opt.clone(),
))
}
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => {
let apply_stride_fn =
stride_map_fn::<TimestampMicrosecondType>(origin, stride, stride_fn);
let scale = timestamp_scale::<TimestampMicrosecondType>();
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
v.and_then(|val| apply_stride_fn(val).ok()),
match *v {
Some(val) => {
let scaled = val
.checked_mul(scale)
.ok_or_else(|| timestamp_scale_overflow_error(val))?;
match stride_fn(stride, scaled, origin) {
Ok(result) => Some(result / scale),
Err(_) => None,
}
}
None => None,
},
tz_opt.clone(),
))
}
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => {
let apply_stride_fn =
stride_map_fn::<TimestampMillisecondType>(origin, stride, stride_fn);
let scale = timestamp_scale::<TimestampMillisecondType>();
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
v.and_then(|val| apply_stride_fn(val).ok()),
match *v {
Some(val) => {
let scaled = val
.checked_mul(scale)
.ok_or_else(|| timestamp_scale_overflow_error(val))?;
match stride_fn(stride, scaled, origin) {
Ok(result) => Some(result / scale),
Err(_) => None,
}
}
None => None,
},
tz_opt.clone(),
))
}
ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => {
let apply_stride_fn =
stride_map_fn::<TimestampSecondType>(origin, stride, stride_fn);
let scale = timestamp_scale::<TimestampSecondType>();
ColumnarValue::Scalar(ScalarValue::TimestampSecond(
v.and_then(|val| apply_stride_fn(val).ok()),
match *v {
Some(val) => {
let scaled = val
.checked_mul(scale)
.ok_or_else(|| timestamp_scale_overflow_error(val))?;
match stride_fn(stride, scaled, origin) {
Ok(result) => Some(result / scale),
Err(_) => None,
}
}
None => None,
},
tz_opt.clone(),
))
}
Expand Down Expand Up @@ -710,20 +751,24 @@ fn date_bin_impl(
T: ArrowTimestampType,
{
let array = as_primitive_array::<T>(array)?;
let scale = match T::UNIT {
Nanosecond => 1,
Microsecond => NANOS_PER_MICRO,
Millisecond => NANOS_PER_MILLI,
Second => NANOSECONDS,
};

let result: PrimitiveArray<T> = array.try_unary(|val| {
stride_fn(stride, val * scale, origin)
.map(|binned| binned / scale)
.map_err(|e| {
arrow::error::ArrowError::ComputeError(e.to_string())
})
})?;
let scale = timestamp_scale::<T>();

let values = array
.iter()
.map(|val| match val {
Some(val) => {
let scaled = val
.checked_mul(scale)
.ok_or_else(|| timestamp_scale_overflow_error(val))?;
Ok(stride_fn(stride, scaled, origin)
.ok()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Jefffrey I will wait for this to get merged as I am doing some cleanup too.

.map(|binned| binned / scale))
}
None => Ok(None),
})
.collect::<Result<Vec<_>>>()?;

let result = PrimitiveArray::<T>::from_iter(values);

let array = result.with_timezone_opt(tz_opt.clone());
Ok(ColumnarValue::Array(Arc::new(array)))
Expand Down Expand Up @@ -764,9 +809,7 @@ fn date_bin_impl(
let nanos = binned_nanos % (NANOSECONDS_IN_DAY);
(nanos / NANOS_PER_MILLI) as i32
})
.map_err(|e| {
arrow::error::ArrowError::ComputeError(e.to_string())
})
.map_err(|e| ArrowError::ComputeError(e.to_string()))
})?;
ColumnarValue::Array(Arc::new(result))
}
Expand All @@ -784,9 +827,7 @@ fn date_bin_impl(
let nanos = binned_nanos % (NANOSECONDS_IN_DAY);
(nanos / NANOS_PER_SEC) as i32
})
.map_err(|e| {
arrow::error::ArrowError::ComputeError(e.to_string())
})
.map_err(|e| ArrowError::ComputeError(e.to_string()))
})?;
ColumnarValue::Array(Arc::new(result))
}
Expand All @@ -804,9 +845,7 @@ fn date_bin_impl(
let nanos = binned_nanos % (NANOSECONDS_IN_DAY);
nanos / NANOS_PER_MICRO
})
.map_err(|e| {
arrow::error::ArrowError::ComputeError(e.to_string())
})
.map_err(|e| ArrowError::ComputeError(e.to_string()))
})?;
ColumnarValue::Array(Arc::new(result))
}
Expand All @@ -821,9 +860,7 @@ fn date_bin_impl(
array.try_unary(|x| {
stride_fn(stride, x, origin)
.map(|binned_nanos| binned_nanos % (NANOSECONDS_IN_DAY))
.map_err(|e| {
arrow::error::ArrowError::ComputeError(e.to_string())
})
.map_err(|e| ArrowError::ComputeError(e.to_string()))
})?;
ColumnarValue::Array(Arc::new(result))
}
Expand Down
21 changes: 20 additions & 1 deletion datafusion/sqllogictest/test_files/date_bin_errors.slt
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,23 @@ select date_bin(
arrow_cast(-9223372036854775808, 'Timestamp(Nanosecond, None)')
);
----
NULL
NULL

# Source timestamp scaling to nanoseconds overflows: should return an error, not panic
query error DataFusion error: Execution error: DATE_BIN source timestamp 9223372036854775807 cannot be represented in nanoseconds
select date_bin(
interval '1 nanosecond',
arrow_cast(9223372036854775807, 'Timestamp(Second, None)'),
timestamp '1970-01-01 00:00:00'
);

# Source timestamp scaling to nanoseconds overflows in array path: should return an error, not panic
query error DataFusion error: Execution error: DATE_BIN source timestamp 9223372036854775807 cannot be represented in nanoseconds
select date_bin(
interval '1 nanosecond',
ts,
timestamp '1970-01-01 00:00:00'
)
from (
values (arrow_cast(9223372036854775807, 'Timestamp(Second, None)'))
) as t(ts);
Loading