From ccf266144d6b3b96bf0012d7135b5b8de73674b8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 3 Jul 2024 08:01:26 -0700 Subject: [PATCH 1/5] fix: Incorrect LEFT JOIN evaluation result on OR conditions (#11203) * fix: Incorrect LEFT JOIN evaluation result on OR conditions * Add a few more test cases * Don't push join filter predicates into join_conditions * Add test case and fix typo * Add test case --------- Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/push_down_filter.rs | 22 ++- datafusion/sqllogictest/test_files/join.slt | 193 +++++++++++++++++++ 2 files changed, 213 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index fa432ad76de53..664fc93a762a5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -424,8 +424,10 @@ fn push_down_all_join( } } + let mut on_filter_join_conditions = vec![]; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; + if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; for on in on_filter { if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { left_push.push(on) @@ -434,7 +436,7 @@ fn push_down_all_join( { right_push.push(on) } else { - join_conditions.push(on) + on_filter_join_conditions.push(on) } } } @@ -450,6 +452,21 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } + // For predicates from join filter, we should check with if a join side is preserved + // in term of join filtering. + if on_left_preserved { + left_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + left_schema, + )); + } + if on_right_preserved { + right_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + right_schema, + )); + } + if let Some(predicate) = conjunction(left_push) { join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); } @@ -459,6 +476,7 @@ fn push_down_all_join( } // Add any new join conditions as the non join predicates + join_conditions.extend(on_filter_join_conditions); join.filter = conjunction(join_conditions); // wrap the join on the filter whose predicates must be kept, if any diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 6732d3e9108b1..3c89109145d70 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -793,3 +793,196 @@ DROP TABLE companies statement ok DROP TABLE leads + +#### +## Test ON clause predicates are not pushed past join for OUTER JOINs +#### + + +# create tables +statement ok +CREATE TABLE employees(emp_id INT, name VARCHAR); + +statement ok +CREATE TABLE department(emp_id INT, dept_name VARCHAR); + +statement ok +INSERT INTO employees (emp_id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Carol'); + +statement ok +INSERT INTO department (emp_id, dept_name) VALUES (1, 'HR'), (3, 'Engineering'), (4, 'Sales'); + +# Can not push the ON filter below an OUTER JOIN +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +02)--SubqueryAlias: e +03)----TableScan: employees projection=[emp_id, name] +04)--SubqueryAlias: d +05)----TableScan: department projection=[dept_name] +physical_plan +01)ProjectionExec: expr=[emp_id@1 as emp_id, name@2 as name, dept_name@0 as dept_name] +02)--NestedLoopJoinExec: join_type=Right, filter=name@0 = Alice OR name@0 = Bob +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +# neither RIGHT OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +RIGHT JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +# neither FULL OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +FULL JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'NotExist1' OR e.name = 'NotExist2'); +---- +1 Alice NULL +2 Bob NULL +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'Alice' OR e.name = 'NotExist'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob NULL +3 Carol NULL + +# Can push the ON filter below the JOIN for INNER JOIN (expect to see a filter below the join) +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)CrossJoin: +02)--SubqueryAlias: e +03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +04)------TableScan: employees projection=[emp_id, name] +05)--SubqueryAlias: d +06)----TableScan: department projection=[dept_name] +physical_plan +01)CrossJoinExec +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: name@1 = Alice OR name@1 = Bob +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)--MemoryExec: partitions=1, partition_sizes=[1] + +# expect no row for Carol +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales + +# OR conditions on Filter (not join filter) +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE (e.name = 'Alice' OR e.name = 'Carol'); +---- +1 Alice HR +3 Carol Engineering + +# Push down OR conditions on Filter through LEFT JOIN if possible +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +logical_plan +01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +02)--Projection: e.emp_id, e.name, d.dept_name +03)----Left Join: e.emp_id = d.emp_id +04)------SubqueryAlias: e +05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +06)----------TableScan: employees projection=[emp_id, name] +07)------SubqueryAlias: d +08)--------TableScan: department projection=[emp_id, dept_name] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +09)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +1 Alice HR +3 Carol Engineering + +statement ok +DROP TABLE employees + +statement ok +DROP TABLE department From 3fbe3d49e5caaf89543e29ddc67820afc1f38d88 Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:17:09 -0500 Subject: [PATCH 2/5] feat: add UDF to_local_time() (#11347) * feat: add UDF `to_local_time()` * chore: support column value in array * chore: lint * chore: fix conversion for us, ms, and s * chore: add more tests for daylight savings time * chore: add function description * refactor: update tests and add examples in description * chore: add description and example * chore: doc chore: doc chore: doc chore: doc chore: doc * chore: stop copying * chore: fix typo * chore: mention that the offset varies based on daylight savings time * refactor: parse timezone once and update examples in description * refactor: replace map..concat with flat_map * chore: add hard code timestamp value in test chore: doc chore: doc * chore: handle errors and remove panics * chore: move some test to slt * chore: clone time_value * chore: typo --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/datetime/mod.rs | 11 +- .../functions/src/datetime/to_local_time.rs | 564 ++++++++++++++++++ .../sqllogictest/test_files/timestamps.slt | 177 ++++++ 3 files changed, 751 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions/src/datetime/to_local_time.rs diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index 9c2f80856bf86..a7e9827d6ca69 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -32,6 +32,7 @@ pub mod make_date; pub mod now; pub mod to_char; pub mod to_date; +pub mod to_local_time; pub mod to_timestamp; pub mod to_unixtime; @@ -50,6 +51,7 @@ make_udf_function!( make_udf_function!(now::NowFunc, NOW, now); make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); make_udf_function!( @@ -108,7 +110,13 @@ pub mod expr_fn { ),( now, "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement", - ),( + ), + ( + to_local_time, + "converts a timezone-aware timestamp to local time (with no offset or timezone information), i.e. strips off the timezone from the timestamp", + args, + ), + ( to_unixtime, "converts a string and optional formats to a Unixtime", args, @@ -277,6 +285,7 @@ pub fn functions() -> Vec> { now(), to_char(), to_date(), + to_local_time(), to_unixtime(), to_timestamp(), to_timestamp_seconds(), diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs new file mode 100644 index 0000000000000..c84d1015bd7ee --- /dev/null +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -0,0 +1,564 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::ops::Add; +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; +use arrow::datatypes::DataType::Timestamp; +use arrow::datatypes::{ + ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; +use arrow::datatypes::{ + TimeUnit, + TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}, +}; + +use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ + ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, +}; + +/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or +/// timezone information). In other words, this function strips off the timezone from the timestamp, +/// while keep the display value of the timestamp the same. +#[derive(Debug)] +pub struct ToLocalTimeFunc { + signature: Signature, +} + +impl Default for ToLocalTimeFunc { + fn default() -> Self { + Self::new() + } +} + +impl ToLocalTimeFunc { + pub fn new() -> Self { + let base_sig = |array_type: TimeUnit| { + [ + Exact(vec![Timestamp(array_type, None)]), + Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]), + ] + }; + + let full_sig = [Nanosecond, Microsecond, Millisecond, Second] + .into_iter() + .flat_map(base_sig) + .collect::>(); + + Self { + signature: Signature::one_of(full_sig, Volatility::Immutable), + } + } + + fn to_local_time(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {}", + args.len() + ); + } + + let time_value = &args[0]; + let arg_type = time_value.data_type(); + match arg_type { + DataType::Timestamp(_, None) => { + // if no timezone specificed, just return the input + Ok(time_value.clone()) + } + // If has timezone, adjust the underlying time value. The current time value + // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore, + // we need to adjust the time value to the local time. See [`adjust_to_local_time`] + // for more details. + // + // Then remove the timezone in return type, i.e. return None + DataType::Timestamp(_, Some(timezone)) => { + let tz: Tz = timezone.parse()?; + + match time_value { + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Array(array) => { + fn transform_array( + array: &ArrayRef, + tz: Tz, + ) -> Result { + let mut builder = PrimitiveBuilder::::new(); + + let primitive_array = as_primitive_array::(array)?; + for ts_opt in primitive_array.iter() { + match ts_opt { + None => builder.append_null(), + Some(ts) => { + let adjusted_ts: i64 = + adjust_to_local_time::(ts, tz)?; + builder.append_value(adjusted_ts) + } + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + + match array.data_type() { + Timestamp(_, None) => { + // if no timezone specificed, just return the input + Ok(time_value.clone()) + } + Timestamp(Nanosecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Microsecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Millisecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Second, Some(_)) => { + transform_array::(array, tz) + } + _ => { + exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type()) + } + } + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + time_value.data_type() + ) + } + } + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + arg_type + ) + } + } + } +} + +/// This function converts a timestamp with a timezone to a timestamp without a timezone. +/// The display value of the adjusted timestamp remain the same, but the underlying timestamp +/// representation is adjusted according to the relative timezone offset to UTC. +/// +/// This function uses chrono to handle daylight saving time changes. +/// +/// For example, +/// +/// ```text +/// '2019-03-31T01:00:00Z'::timestamp at time zone 'Europe/Brussels' +/// ``` +/// +/// is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00+01:00 +/// ``` +/// +/// and is represented in DataFusion as: +/// +/// ```text +/// TimestampNanosecond(Some(1_553_990_400_000_000_000), Some("Europe/Brussels")) +/// ``` +/// +/// To strip off the timezone while keeping the display value the same, we need to +/// adjust the underlying timestamp with the timezone offset value using `adjust_to_local_time()` +/// +/// ```text +/// adjust_to_local_time(1_553_990_400_000_000_000, "Europe/Brussels") --> 1_553_994_000_000_000_000 +/// ``` +/// +/// The difference between `1_553_990_400_000_000_000` and `1_553_994_000_000_000_000` is +/// `3600_000_000_000` ns, which corresponds to 1 hour. This matches with the timezone +/// offset for "Europe/Brussels" for this date. +/// +/// Note that the offset varies with daylight savings time (DST), which makes this tricky! For +/// example, timezone "Europe/Brussels" has a 2-hour offset during DST and a 1-hour offset +/// when DST ends. +/// +/// Consequently, DataFusion can represent the timestamp in local time (with no offset or +/// timezone information) as +/// +/// ```text +/// TimestampNanosecond(Some(1_553_994_000_000_000_000), None) +/// ``` +/// +/// which is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00 +/// ``` +/// +/// See `test_adjust_to_local_time()` for example +fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { + fn convert_timestamp(ts: i64, converter: F) -> Result> + where + F: Fn(i64) -> MappedLocalTime>, + { + match converter(ts) { + MappedLocalTime::Ambiguous(earliest, latest) => exec_err!( + "Ambiguous timestamp. Do you mean {:?} or {:?}", + earliest, + latest + ), + MappedLocalTime::None => exec_err!( + "The local time does not exist because there is a gap in the local time." + ), + MappedLocalTime::Single(date_time) => Ok(date_time), + } + } + + let date_time = match T::UNIT { + Nanosecond => Utc.timestamp_nanos(ts), + Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?, + Millisecond => convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))?, + Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?, + }; + + let offset_seconds: i64 = tz + .offset_from_utc_datetime(&date_time.naive_utc()) + .fix() + .local_minus_utc() as i64; + + let adjusted_date_time = date_time.add( + // This should not fail under normal circumstances as the + // maximum possible offset is 26 hours (93,600 seconds) + TimeDelta::try_seconds(offset_seconds) + .ok_or(DataFusionError::Internal("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000".to_string()))?, + ); + + // convert the naive datetime back to i64 + match T::UNIT { + Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or( + DataFusionError::Internal( + "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807".to_string(), + ), + ), + Microsecond => Ok(adjusted_date_time.timestamp_micros()), + Millisecond => Ok(adjusted_date_time.timestamp_millis()), + Second => Ok(adjusted_date_time.timestamp()), + } +} + +impl ScalarUDFImpl for ToLocalTimeFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_local_time" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + arg_types.len() + ); + } + + match &arg_types[0] { + Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)), + Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), + Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), + Timestamp(Second, _) => Ok(Timestamp(Second, None)), + _ => exec_err!( + "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0] + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + args.len() + ); + } + + self.to_local_time(args) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; + use arrow::datatypes::{DataType, TimeUnit}; + use chrono::NaiveDateTime; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::{adjust_to_local_time, ToLocalTimeFunc}; + + #[test] + fn test_adjust_to_local_time() { + let timestamp_str = "2020-03-31T13:40:00"; + let tz: arrow::array::timezone::Tz = + "America/New_York".parse().expect("Invalid timezone"); + + let timestamp = timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + + let expected_timestamp = timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + + let res = adjust_to_local_time::(timestamp, tz).unwrap(); + assert_eq!(res, expected_timestamp); + } + + #[test] + fn test_to_local_time_scalar() { + let timezone = Some("Europe/Brussels".into()); + let timestamps_with_timezone = vec![ + ( + ScalarValue::TimestampNanosecond( + Some(1_123_123_000_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampNanosecond(Some(1_123_130_200_000_000_000), None), + ), + ( + ScalarValue::TimestampMicrosecond( + Some(1_123_123_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMicrosecond(Some(1_123_130_200_000_000), None), + ), + ( + ScalarValue::TimestampMillisecond( + Some(1_123_123_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMillisecond(Some(1_123_130_200_000), None), + ), + ( + ScalarValue::TimestampSecond(Some(1_123_123_000), timezone), + ScalarValue::TimestampSecond(Some(1_123_130_200), None), + ), + ]; + + for (input, expected) in timestamps_with_timezone { + test_to_local_time_helper(input, expected); + } + } + + #[test] + fn test_timezone_with_daylight_savings() { + let timezone_str = "America/New_York"; + let tz: arrow::array::timezone::Tz = + timezone_str.parse().expect("Invalid timezone"); + + // Test data: + // ( + // the string display of the input timestamp, + // the i64 representation of the timestamp before adjustment in nanosecond, + // the i64 representation of the timestamp after adjustment in nanosecond, + // ) + let test_cases = vec![ + ( + // DST time + "2020-03-31T13:40:00", + 1_585_676_400_000_000_000, + 1_585_662_000_000_000_000, + ), + ( + // End of DST + "2020-11-04T14:06:40", + 1_604_516_800_000_000_000, + 1_604_498_800_000_000_000, + ), + ]; + + for ( + input_timestamp_str, + expected_input_timestamp, + expected_adjusted_timestamp, + ) in test_cases + { + let input_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(input_timestamp, expected_input_timestamp); + + let expected_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(expected_timestamp, expected_adjusted_timestamp); + + let input = ScalarValue::TimestampNanosecond( + Some(input_timestamp), + Some(timezone_str.into()), + ); + let expected = + ScalarValue::TimestampNanosecond(Some(expected_timestamp), None); + test_to_local_time_helper(input, expected) + } + } + + fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let res = ToLocalTimeFunc::new() + .invoke(&[ColumnarValue::Scalar(input)]) + .unwrap(); + match res { + ColumnarValue::Scalar(res) => { + assert_eq!(res, expected); + } + _ => panic!("unexpected return type"), + } + } + + #[test] + fn test_to_local_time_timezones_array() { + let cases = [ + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + None::>, + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + Some("+01:00".into()), + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ]; + + cases.iter().for_each(|(source, _tz_opt, expected)| { + let input = source + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let result = ToLocalTimeFunc::new() + .invoke(&[ColumnarValue::Array(Arc::new(input))]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + let left = arrow::array::cast::as_primitive_array::< + TimestampNanosecondType, + >(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } +} diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 2216dbfa5fd58..f4e492649b9f8 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2844,3 +2844,180 @@ select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - query error select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); + +########## +## Test to_local_time function +########## + +# invalid number of arguments -- no argument +statement error +select to_local_time(); + +# invalid number of arguments -- more than 1 argument +statement error +select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); + +# invalid argument data type +statement error DataFusion error: Execution error: The to_local_time function can only accept timestamp as the arg, got Utf8 +select to_local_time('2024-04-01T00:00:20Z'); + +# invalid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "Europe/timezone": failed to parse timezone +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/timezone'); + +# valid query +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE '+05:00'); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); +---- +2024-04-01T00:00:20 + +query PTPT +select + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +from ( + select '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' as time +); +---- +2024-04-01T00:00:20+02:00 Timestamp(Nanosecond, Some("Europe/Brussels")) 2024-04-01T00:00:20 Timestamp(Nanosecond, None) + +# use to_local_time() in date_bin() +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')); +---- +2024-04-01T00:00:00 + +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels'; +---- +2024-04-01T00:00:00+02:00 + +# test using to_local_time() on array values +statement ok +create table t AS +VALUES + ('2024-01-01T00:00:01Z'), + ('2024-02-01T00:00:01Z'), + ('2024-03-01T00:00:01Z'), + ('2024-04-01T00:00:01Z'), + ('2024-05-01T00:00:01Z'), + ('2024-06-01T00:00:01Z'), + ('2024-07-01T00:00:01Z'), + ('2024-08-01T00:00:01Z'), + ('2024-09-01T00:00:01Z'), + ('2024-10-01T00:00:01Z'), + ('2024-11-01T00:00:01Z'), + ('2024-12-01T00:00:01Z') +; + +statement ok +create view t_utc as +select column1::timestamp AT TIME ZONE 'UTC' as "column1" +from t; + +statement ok +create view t_timezone as +select column1::timestamp AT TIME ZONE 'Europe/Brussels' as "column1" +from t; + +query PPT +select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(column1::timestamp)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_timezone; +---- +2024-01-01T00:00:01+01:00 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01+01:00 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01+01:00 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01+02:00 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01+02:00 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01+02:00 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01+02:00 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01+02:00 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01+02:00 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01+02:00 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01+01:00 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01+01:00 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +# combine to_local_time() with date_bin() +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_utc; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_timezone; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +statement ok +drop table t; + +statement ok +drop view t_utc; + +statement ok +drop view t_timezone; From 4f9849440e59058bb0a14929b9348a71375a55c4 Mon Sep 17 00:00:00 2001 From: wiedld Date: Wed, 10 Jul 2024 11:21:01 -0700 Subject: [PATCH 3/5] Track parquet writer encoding memory usage on MemoryPool (#11345) * feat(11344): track memory used for non-parallel writes * feat(11344): track memory usage during parallel writes * test(11344): create bounded stream for testing * test(11344): test ParquetSink memory reservation * feat(11344): track bytes in file writer * refactor(11344): tweak the ordering to add col bytes to rg_reservation, before selecting shrinking for data bytes flushed * refactor: move each col_reservation and rg_reservation to match the parallelized call stack for col vs rg * test(11344): add memory_limit enforcement test for parquet sink * chore: cleanup to remove unnecessary reservation management steps * fix: fix CI test failure due to file extension rename --- .../src/datasource/file_format/parquet.rs | 165 ++++++++++++++++-- datafusion/core/src/test_util/mod.rs | 36 ++++ datafusion/core/tests/memory_limit/mod.rs | 25 +++ 3 files changed, 216 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27d783cd89b5f..694c949285374 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -48,6 +48,7 @@ use datafusion_common::{ DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -749,9 +750,13 @@ impl DataSink for ParquetSink { parquet_props.writer_options().clone(), ) .await?; + let mut reservation = + MemoryConsumer::new(format!("ParquetSink[{}]", path)) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; + reservation.try_resize(writer.memory_size())?; } let file_metadata = writer .close() @@ -771,6 +776,7 @@ impl DataSink for ParquetSink { let schema = self.get_writer_schema(); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); + let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { let file_metadata = output_single_parquet_file_parallelized( writer, @@ -778,6 +784,7 @@ impl DataSink for ParquetSink { schema, props.writer_options(), parallel_options_clone, + pool, ) .await?; Ok((path, file_metadata)) @@ -818,14 +825,16 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, -) -> Result { + mut reservation: MemoryReservation, +) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; + reservation.try_resize(writer.memory_size())?; } - Ok(writer) + Ok((writer, reservation)) } -type ColumnWriterTask = SpawnedTask>; +type ColumnWriterTask = SpawnedTask>; type ColSender = Sender; /// Spawns a parallel serialization task for each column @@ -835,6 +844,7 @@ fn spawn_column_parallel_row_group_writer( schema: Arc, parquet_props: Arc, max_buffer_size: usize, + pool: &Arc, ) -> Result<(Vec, Vec)> { let schema_desc = arrow_to_parquet_schema(&schema)?; let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; @@ -848,7 +858,13 @@ fn spawn_column_parallel_row_group_writer( mpsc::channel::(max_buffer_size); col_array_channels.push(send_array); - let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer)); + let reservation = + MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); + let task = SpawnedTask::spawn(column_serializer_task( + recieve_array, + writer, + reservation, + )); col_writer_tasks.push(task); } @@ -864,7 +880,7 @@ struct ParallelParquetWriterOptions { /// This is the return type of calling [ArrowColumnWriter].close() on each column /// i.e. the Vec of encoded columns which can be appended to a row group -type RBStreamSerializeResult = Result<(Vec, usize)>; +type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; /// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective /// parallel column serializers. @@ -895,16 +911,22 @@ async fn send_arrays_to_col_writers( fn spawn_rg_join_and_finalize_task( column_writer_tasks: Vec, rg_rows: usize, + pool: &Arc, ) -> SpawnedTask { + let mut rg_reservation = + MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); + SpawnedTask::spawn(async move { let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - let writer = task.join_unwind().await?; + let (writer, _col_reservation) = task.join_unwind().await?; + let encoded_size = writer.get_estimated_total_bytes(); + rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); } - Ok((finalized_rg, rg_rows)) + Ok((finalized_rg, rg_reservation, rg_rows)) }) } @@ -922,6 +944,7 @@ fn spawn_parquet_parallel_serialization_task( schema: Arc, writer_props: Arc, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> SpawnedTask> { SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; @@ -931,6 +954,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; let mut current_rg_rows = 0; @@ -957,6 +981,7 @@ fn spawn_parquet_parallel_serialization_task( let finalize_rg_task = spawn_rg_join_and_finalize_task( column_writer_handles, max_row_group_rows, + &pool, ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { @@ -973,6 +998,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; } } @@ -981,8 +1007,11 @@ fn spawn_parquet_parallel_serialization_task( drop(col_array_channels); // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows if current_rg_rows > 0 { - let finalize_rg_task = - spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + current_rg_rows, + &pool, + ); serialize_tx.send(finalize_rg_task).await.map_err(|_| { DataFusionError::Internal( @@ -1002,9 +1031,13 @@ async fn concatenate_parallel_row_groups( schema: Arc, writer_props: Arc, mut object_store_writer: Box, + pool: Arc, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut file_reservation = + MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( merged_buff.clone(), @@ -1015,15 +1048,20 @@ async fn concatenate_parallel_row_groups( while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, _cnt) = result?; + let (serialized_columns, mut rg_reservation, _cnt) = result?; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; + rg_reservation.free(); + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + file_reservation.try_resize(buff_to_flush.len())?; + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; buff_to_flush.clear(); + file_reservation.try_resize(buff_to_flush.len())?; // will set to zero } } rg_out.close()?; @@ -1034,6 +1072,7 @@ async fn concatenate_parallel_row_groups( object_store_writer.write_all(final_buff.as_slice()).await?; object_store_writer.shutdown().await?; + file_reservation.free(); Ok(file_metadata) } @@ -1048,6 +1087,7 @@ async fn output_single_parquet_file_parallelized( output_schema: Arc, parquet_props: &WriterProperties, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> Result { let max_rowgroups = parallel_options.max_parallel_row_groups; // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel @@ -1061,12 +1101,14 @@ async fn output_single_parquet_file_parallelized( output_schema.clone(), arc_props.clone(), parallel_options, + Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, output_schema.clone(), arc_props.clone(), object_store_writer, + pool, ) .await?; @@ -1158,8 +1200,10 @@ mod tests { use super::super::test_util::scan_format; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::physical_plan::collect; + use crate::test_util::bounded_stream; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; use super::*; @@ -2177,4 +2221,105 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn parquet_sink_write_memory_reservation() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); + + let file_sink_config = FileSinkConfig { + object_store_url: object_store_url.clone(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + overwrite: true, + keep_partition_by_columns: false, + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); + + let mut write_task = parquet_sink.write_all( + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; + } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); + + Ok(()) + } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) + } } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 059fa8fc6da77..ba0509f3f51ac 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -366,3 +366,39 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +struct BoundedStream { + limit: usize, + count: usize, + batch: RecordBatch, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + return Poll::Ready(None); + } + self.count += 1; + Poll::Ready(Some(Ok(self.batch.clone()))) + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// Creates an bounded stream for testing purposes. +pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + count: 0, + limit, + batch, + }) +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f61ee5d9ab984..f7402357d1c76 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -31,6 +31,7 @@ use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; use std::any::Any; use std::sync::{Arc, OnceLock}; +use tokio::fs::File; use datafusion::datasource::streaming::StreamingTable; use datafusion::datasource::{MemTable, TableProvider}; @@ -323,6 +324,30 @@ async fn oom_recursive_cte() { .await } +#[tokio::test] +async fn oom_parquet_sink() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + // TODO: update error handling in ParquetSink + "Unable to send array to writer!", + ]) + .with_memory_limit(200_000) + .run() + .await +} + /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] From fac9e696adb7085ed5658bdc827e17eeb53b4ea3 Mon Sep 17 00:00:00 2001 From: wiedld Date: Fri, 12 Jul 2024 04:04:42 -0700 Subject: [PATCH 4/5] fix(11397): surface proper errors in ParquetSink (#11399) * fix(11397): do not surface errors for closed channels, and instead let the task join errors be surfaced * fix(11397): terminate early on channel send failure Add Optimizer Sanity Checker, improve sortedness equivalence properties (#11196) * Initial optimizer sanity checker. Only includes sort reqs, docs will be added. * Add distro and pipeline friendly checks * Also check the plans we create are correct. * Add distribution test cases using global limit exec. * Add test for multiple children using SortMergeJoinExec. * Move PipelineChecker to SanityCheckPlan * Fix some tests and add docs * Add some test docs and fix clippy diagnostics. * Fix some failing tests * Replace PipelineChecker with SanityChecker in .slt files. * Initial commit * Slt tests pass * Resolve linter errors * Minor changes * Minor changes * Minor changes * Minor changes * Sort PreservingMerge clear per partition * Minor changes * Update output_requirements.rs * Address reviews * Update datafusion/core/src/physical_optimizer/optimizer.rs Co-authored-by: Mehmet Ozan Kabak * Update datafusion/core/src/physical_optimizer/sanity_checker.rs Co-authored-by: Mehmet Ozan Kabak * Address reviews * Minor changes * Apply suggestions from code review Co-authored-by: Andrew Lamb * Update comment * Add map implementation --------- Co-authored-by: Erman Yafay Co-authored-by: berkaysynnada Co-authored-by: Mehmet Ozan Kabak Co-authored-by: Andrew Lamb --- .../src/datasource/file_format/parquet.rs | 32 +- datafusion/core/src/physical_optimizer/mod.rs | 2 +- .../core/src/physical_optimizer/optimizer.rs | 16 +- .../physical_optimizer/output_requirements.rs | 4 +- .../physical_optimizer/pipeline_checker.rs | 334 --------- .../src/physical_optimizer/sanity_checker.rs | 666 ++++++++++++++++++ .../src/physical_optimizer/sort_pushdown.rs | 31 +- .../sort_preserving_repartition_fuzz.rs | 6 +- datafusion/core/tests/memory_limit/mod.rs | 4 +- datafusion/core/tests/sql/joins.rs | 2 +- .../physical-expr/src/equivalence/class.rs | 67 ++ .../physical-expr/src/equivalence/mod.rs | 6 +- .../physical-expr/src/equivalence/ordering.rs | 6 +- .../src/equivalence/properties.rs | 115 ++- datafusion/physical-expr/src/lib.rs | 2 +- .../physical-plan/src/coalesce_partitions.rs | 2 +- datafusion/physical-plan/src/filter.rs | 20 +- .../physical-plan/src/repartition/mod.rs | 5 + .../src/sorts/sort_preserving_merge.rs | 16 +- datafusion/physical-plan/src/union.rs | 96 ++- datafusion/physical-plan/src/windows/mod.rs | 8 +- .../sqllogictest/test_files/aggregate.slt | 2 +- .../sqllogictest/test_files/explain.slt | 6 +- datafusion/sqllogictest/test_files/joins.slt | 50 +- datafusion/sqllogictest/test_files/window.slt | 6 +- 25 files changed, 1027 insertions(+), 477 deletions(-) delete mode 100644 datafusion/core/src/physical_optimizer/pipeline_checker.rs create mode 100644 datafusion/core/src/physical_optimizer/sanity_checker.rs diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 694c949285374..6271d8af37862 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -893,12 +893,12 @@ async fn send_arrays_to_col_writers( let mut next_channel = 0; for (array, field) in rb.columns().iter().zip(schema.fields()) { for c in compute_leaves(field, array)? { - col_array_channels[next_channel] - .send(c) - .await - .map_err(|_| { - DataFusionError::Internal("Unable to send array to writer!".into()) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if col_array_channels[next_channel].send(c).await.is_err() { + return Ok(()); + } + next_channel += 1; } } @@ -984,11 +984,11 @@ fn spawn_parquet_parallel_serialization_task( &pool, ); - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } current_rg_rows = 0; rb = rb.slice(rows_left, rb.num_rows() - rows_left); @@ -1013,11 +1013,11 @@ fn spawn_parquet_parallel_serialization_task( &pool, ); - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } } Ok(()) diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index e1bde36bd6fed..9ad05bf496e59 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -30,10 +30,10 @@ pub mod join_selection; pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; -pub mod pipeline_checker; pub mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; +pub mod sanity_checker; mod sort_pushdown; pub mod topk_aggregation; pub mod update_aggr_exprs; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 6880a5433943e..2d9744ad23dd3 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -30,7 +30,7 @@ use crate::physical_optimizer::enforce_sorting::EnforceSorting; use crate::physical_optimizer::join_selection::JoinSelection; use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use crate::physical_optimizer::output_requirements::OutputRequirements; -use crate::physical_optimizer::pipeline_checker::PipelineChecker; +use crate::physical_optimizer::sanity_checker::SanityCheckPlan; use crate::physical_optimizer::topk_aggregation::TopKAggregation; use crate::{error::Result, physical_plan::ExecutionPlan}; @@ -124,11 +124,15 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), + // The SanityCheckPlan rule checks whether the order and + // distribution requirements of each node in the plan + // is satisfied. It will also reject non-runnable query + // plans that use pipeline-breaking operators on infinite + // input(s). The rule generates a diagnostic error + // message for invalid plans. It makes no changes to the + // given query plan; i.e. it only acts as a final + // gatekeeping rule. + Arc::new(SanityCheckPlan::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 67b38dba90ca0..671bb437d5fa2 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -248,7 +248,9 @@ fn require_top_ordering_helper( if children.len() != 1 { Ok((plan, false)) } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let req_ordering = sort_exec.properties().output_ordering().unwrap_or(&[]); + // In case of constant columns, output ordering of SortExec would give an empty set. + // Therefore; we check the sort expression field of the SortExec to assign the requirements. + let req_ordering = sort_exec.expr(); let req_dist = sort_exec.required_input_distribution()[0].clone(); let reqs = PhysicalSortRequirement::from_sort_exprs(req_ordering); Ok(( diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs deleted file mode 100644 index 5c6a0ab8ea7fa..0000000000000 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ /dev/null @@ -1,334 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [PipelineChecker] rule ensures that a given plan can accommodate its -//! infinite sources, if there are any. It will reject non-runnable query plans -//! that use pipeline-breaking operators on infinite input(s). - -use std::sync::Arc; - -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; - -use datafusion_common::config::OptimizerOptions; -use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; -use datafusion_physical_plan::joins::SymmetricHashJoinExec; - -/// The PipelineChecker rule rejects non-runnable query plans that use -/// pipeline-breaking operators on infinite input(s). -#[derive(Default)] -pub struct PipelineChecker {} - -impl PipelineChecker { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for PipelineChecker { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - plan.transform_up(|p| check_finiteness_requirements(p, &config.optimizer)) - .data() - } - - fn name(&self) -> &str { - "PipelineChecker" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// This function propagates finiteness information and rejects any plan with -/// pipeline-breaking operators acting on infinite inputs. -pub fn check_finiteness_requirements( - input: Arc, - optimizer_options: &OptimizerOptions, -) -> Result>> { - if let Some(exec) = input.as_any().downcast_ref::() { - if !(optimizer_options.allow_symmetric_joins_without_pruning - || (exec.check_if_order_information_available()? && is_prunable(exec))) - { - return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ - the 'allow_symmetric_joins_without_pruning' configuration flag"); - } - } - if !input.execution_mode().pipeline_friendly() { - plan_err!( - "Cannot execute pipeline breaking queries, operator: {:?}", - input - ) - } else { - Ok(Transformed::no(input)) - } -} - -/// This function returns whether a given symmetric hash join is amenable to -/// data pruning. For this to be possible, it needs to have a filter where -/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support -/// interval calculations. -/// -/// [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -/// [`Operator`]: datafusion_expr::Operator -fn is_prunable(join: &SymmetricHashJoinExec) -> bool { - join.filter().map_or(false, |filter| { - check_support(filter.expression(), &join.schema()) - && filter - .schema() - .fields() - .iter() - .all(|f| is_datatype_supported(f.data_type())) - }) -} - -#[cfg(test)] -mod sql_tests { - use super::*; - use crate::physical_optimizer::test_utils::{ - BinaryTestCase, QueryCase, SourceType, UnaryTestCase, - }; - - #[tokio::test] - async fn test_hash_left_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_right_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_inner_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "Join Error".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_full_outer_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_aggregate() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: AggregateExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_hash_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test - LIMIT 5".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_single_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_cross_join() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test4 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), - cases: vec![ - Arc::new(test1), - Arc::new(test2), - Arc::new(test3), - Arc::new(test4), - ], - error_operator: "operator: CrossJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_analyzer() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: false, - }; - let case = QueryCase { - sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "Analyze Error".to_string(), - }; - - case.run().await?; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs new file mode 100644 index 0000000000000..083b42f7400bc --- /dev/null +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -0,0 +1,666 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The [SanityCheckPlan] rule ensures that a given plan can +//! accommodate its infinite sources, if there are any. It will reject +//! non-runnable query plans that use pipeline-breaking operators on +//! infinite input(s). In addition, it will check if all order and +//! distribution requirements of a plan are satisfied by its children. + +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::ExecutionPlan; + +use datafusion_common::config::{ConfigOptions, OptimizerOptions}; +use datafusion_common::plan_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; + +use itertools::izip; + +/// The SanityCheckPlan rule rejects the following query plans: +/// 1. Invalid plans containing nodes whose order and/or distribution requirements +/// are not satisfied by their children. +/// 2. Plans that use pipeline-breaking operators on infinite input(s), +/// it is impossible to execute such queries (they will never generate output nor finish) +#[derive(Default)] +pub struct SanityCheckPlan {} + +impl SanityCheckPlan { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for SanityCheckPlan { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|p| check_plan_sanity(p, &config.optimizer)) + .data() + } + + fn name(&self) -> &str { + "SanityCheckPlan" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function propagates finiteness information and rejects any plan with +/// pipeline-breaking operators acting on infinite inputs. +pub fn check_finiteness_requirements( + input: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + if let Some(exec) = input.as_any().downcast_ref::() { + if !(optimizer_options.allow_symmetric_joins_without_pruning + || (exec.check_if_order_information_available()? && is_prunable(exec))) + { + return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ + the 'allow_symmetric_joins_without_pruning' configuration flag"); + } + } + if !input.execution_mode().pipeline_friendly() { + plan_err!( + "Cannot execute pipeline breaking queries, operator: {:?}", + input + ) + } else { + Ok(Transformed::no(input)) + } +} + +/// This function returns whether a given symmetric hash join is amenable to +/// data pruning. For this to be possible, it needs to have a filter where +/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support +/// interval calculations. +/// +/// [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr +/// [`Operator`]: datafusion_expr::Operator +fn is_prunable(join: &SymmetricHashJoinExec) -> bool { + join.filter().map_or(false, |filter| { + check_support(filter.expression(), &join.schema()) + && filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())) + }) +} + +/// Ensures that the plan is pipeline friendly and the order and +/// distribution requirements from its children are satisfied. +pub fn check_plan_sanity( + plan: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + check_finiteness_requirements(plan.clone(), optimizer_options)?; + + for (child, child_sort_req, child_dist_req) in izip!( + plan.children().iter(), + plan.required_input_ordering().iter(), + plan.required_input_distribution().iter() + ) { + let child_eq_props = child.equivalence_properties(); + if let Some(child_sort_req) = child_sort_req { + if !child_eq_props.ordering_satisfy_requirement(child_sort_req) { + let child_plan_str = get_plan_string(child); + return plan_err!( + "Child: {:?} does not satisfy parent order requirements: {:?}", + child_plan_str, + child_sort_req + ); + } + } + + if !child + .output_partitioning() + .satisfy(child_dist_req, child_eq_props) + { + let child_plan_str = get_plan_string(child); + return plan_err!( + "Child: {:?} does not satisfy parent distribution requirements: {:?}", + child_plan_str, + child_dist_req + ); + } + } + + Ok(Transformed::no(plan)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::physical_optimizer::test_utils::{ + bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, + repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + BinaryTestCase, QueryCase, SourceType, UnaryTestCase, + }; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::Result; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::Partitioning; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::repartition::RepartitionExec; + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c9", DataType::Int32, true)])) + } + + fn create_test_schema2() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])) + } + + /// Check if sanity checker should accept or reject plans. + fn assert_sanity_check(plan: &Arc, is_sane: bool) { + let sanity_checker = SanityCheckPlan::new(); + let opts = ConfigOptions::default(); + assert_eq!( + sanity_checker.optimize(plan.clone(), &opts).is_ok(), + is_sane + ); + } + + /// Check if the plan we created is as expected by comparing the plan + /// formatted as a string. + fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { + let plan_str = displayable(plan).indent(true).to_string(); + let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); + assert_eq!(actual_lines, expected_lines); + } + + #[tokio::test] + async fn test_hash_left_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_right_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_inner_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "Join Error".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_full_outer_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_aggregate() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: AggregateExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_window_agg_hash_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test + LIMIT 5".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_window_agg_single_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_cross_join() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test4 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), + cases: vec![ + Arc::new(test1), + Arc::new(test2), + Arc::new(test3), + Arc::new(test4), + ], + error_operator: "operator: CrossJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_analyzer() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: false, + }; + let case = QueryCase { + sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "Analyze Error".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + /// Tests that plan is valid when the sort requirements are satisfied. + async fn test_bounded_window_agg_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let sort = sort_exec(sort_exprs.clone(), source); + let bw = bounded_window_exec("c9", sort_exprs, sort); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + assert_sanity_check(&bw, true); + Ok(()) + } + + #[tokio::test] + /// Tests that plan is invalid when the sort requirements are not satisfied. + async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let bw = bounded_window_exec("c9", sort_exprs, source); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&bw, false); + Ok(()) + } + + #[tokio::test] + /// A valid when a single partition requirement + /// is satisfied. + async fn test_global_limit_single_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) + } + + #[tokio::test] + /// An invalid plan when a single partition requirement + /// is not satisfied. + async fn test_global_limit_multi_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(repartition_exec(source)); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&limit, false); + Ok(()) + } + + #[tokio::test] + /// A plan with no requirements should satisfy. + async fn test_local_limit() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = local_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "LocalLimitExec: fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) + } + + #[tokio::test] + /// Valid plan with multiple children satisfy both order and distribution. + async fn test_sort_merge_join_satisfied() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&smj, true); + Ok(()) + } + + #[tokio::test] + /// Invalid case when the order is not satisfied by the 2nd + /// child. + async fn test_sort_merge_join_order_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let right = memory_exec(&schema2); + let sort_exprs1 = vec![sort_expr_options( + "c9", + &source1.schema(), + SortOptions::default(), + )]; + let left = sort_exec(sort_exprs1, source1); + // Missing sort of the right child here.. + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) + } + + #[tokio::test] + /// Invalid case when the distribution is not satisfied by the 2nd + /// child. + async fn test_sort_merge_join_dist_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::RoundRobinBatch(10), + )?); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + // Missing hash partitioning on right child. + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 83531da3ca8ff..36ac4b22d5942 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -245,11 +245,38 @@ fn try_pushdown_requirements_to_join( sort_expr: Vec, push_side: JoinSide, ) -> Result>>>> { + let left_eq_properties = smj.left().equivalence_properties(); + let right_eq_properties = smj.right().equivalence_properties(); + let mut smj_required_orderings = smj.required_input_ordering(); + let right_requirement = smj_required_orderings.swap_remove(1); + let left_requirement = smj_required_orderings.swap_remove(0); let left_ordering = smj.left().output_ordering().unwrap_or(&[]); let right_ordering = smj.right().output_ordering().unwrap_or(&[]); let (new_left_ordering, new_right_ordering) = match push_side { - JoinSide::Left => (sort_expr.as_slice(), right_ordering), - JoinSide::Right => (left_ordering, sort_expr.as_slice()), + JoinSide::Left => { + let left_eq_properties = + left_eq_properties.clone().with_reorder(sort_expr.clone()); + if left_eq_properties + .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) + { + // After re-ordering requirement is still satisfied + (sort_expr.as_slice(), right_ordering) + } else { + return Ok(None); + } + } + JoinSide::Right => { + let right_eq_properties = + right_eq_properties.clone().with_reorder(sort_expr.clone()); + if right_eq_properties + .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) + { + // After re-ordering requirement is still satisfied + (left_ordering, sort_expr.as_slice()) + } else { + return Ok(None); + } + } }; let join_type = smj.join_type(); let probe_side = SortMergeJoinExec::probe_side(&join_type); diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 21ef8a7c2110f..f00d17a06ffc9 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -41,7 +41,7 @@ mod sp_repartition_fuzz_tests { use datafusion_physical_expr::{ equivalence::{EquivalenceClass, EquivalenceProperties}, expressions::{col, Column}, - PhysicalExpr, PhysicalSortExpr, + ConstExpr, PhysicalExpr, PhysicalSortExpr, }; use test_utils::add_empty_batches; @@ -80,7 +80,7 @@ mod sp_repartition_fuzz_tests { // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); + eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -149,7 +149,7 @@ mod sp_repartition_fuzz_tests { // Fill constant columns for constant in eq_properties.constants() { - let col = constant.as_any().downcast_ref::().unwrap(); + let col = constant.expr().as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f7402357d1c76..7ef24609e238d 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -340,8 +340,8 @@ async fn oom_parquet_sink() { path.to_string_lossy() )) .with_expected_errors(vec![ - // TODO: update error handling in ParquetSink - "Unable to send array to writer!", + "Failed to allocate additional", + "for ParquetSink(ArrowColumnWriter)", ]) .with_memory_limit(200_000) .run() diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index fad9b94b01120..addabc8a36127 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -230,7 +230,7 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.strip_backtrace(), "PipelineChecker\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + assert_eq!(e.strip_backtrace(), "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") } } Ok(()) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index b4d12e963611e..6c12acb934bee 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -28,6 +28,73 @@ use crate::{ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; +#[derive(Debug, Clone)] +/// A structure representing a expression known to be constant in a physical execution plan. +/// +/// The `ConstExpr` struct encapsulates an expression that is constant during the execution +/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would +/// be known constant +/// +/// # Fields +/// +/// - `expr`: Constant expression for a node in the physical plan. +/// +/// - `across_partitions`: A boolean flag indicating whether the constant expression is +/// valid across partitions. If set to `true`, the constant expression has same value for all partitions. +/// If set to `false`, the constant expression may have different values for different partitions. +pub struct ConstExpr { + expr: Arc, + across_partitions: bool, +} + +impl ConstExpr { + pub fn new(expr: Arc) -> Self { + Self { + expr, + // By default, assume constant expressions are not same accross partitions. + across_partitions: false, + } + } + + pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { + self.across_partitions = across_partitions; + self + } + + pub fn across_partitions(&self) -> bool { + self.across_partitions + } + + pub fn expr(&self) -> &Arc { + &self.expr + } + + pub fn owned_expr(self) -> Arc { + self.expr + } + + pub fn map(&self, f: F) -> Option + where + F: Fn(&Arc) -> Option>, + { + let maybe_expr = f(&self.expr); + maybe_expr.map(|expr| Self { + expr, + across_partitions: self.across_partitions, + }) + } +} + +/// Checks whether `expr` is among in the `const_exprs`. +pub fn const_exprs_contains( + const_exprs: &[ConstExpr], + expr: &Arc, +) -> bool { + const_exprs + .iter() + .any(|const_expr| const_expr.expr.eq(expr)) +} + /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by /// equality predicates (e.g. `a = b`), typically equi-join conditions and diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 7faf2caae01c9..5eb8a19e3d672 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -27,7 +27,7 @@ mod ordering; mod projection; mod properties; -pub use class::{EquivalenceClass, EquivalenceGroup}; +pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{join_equivalence_properties, EquivalenceProperties}; @@ -205,7 +205,7 @@ mod tests { // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); + eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -482,7 +482,7 @@ mod tests { // Fill constant columns for constant in &eq_properties.constants { - let col = constant.as_any().downcast_ref::().unwrap(); + let col = constant.expr().as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 7857d9df726e9..ac9d64e486ac6 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -234,7 +234,7 @@ mod tests { }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use crate::{PhysicalExpr, PhysicalSortExpr}; + use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; @@ -554,7 +554,9 @@ mod tests { let eq_group = EquivalenceGroup::new(eq_group); eq_properties.add_equivalence_group(eq_group); - let constants = constants.into_iter().cloned(); + let constants = constants + .into_iter() + .map(|expr| ConstExpr::new(expr.clone()).with_across_partitions(true)); eq_properties = eq_properties.add_constants(constants); let reqs = convert_to_sort_exprs(&reqs); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 7bf389ecfdf32..e3a2d1c753ca4 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -19,12 +19,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use super::ordering::collapse_lex_ordering; +use crate::equivalence::class::const_exprs_contains; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::Literal; use crate::{ - physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + physical_exprs_contains, ConstExpr, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; @@ -92,7 +93,7 @@ pub struct EquivalenceProperties { /// Expressions whose values are constant throughout the table. /// TODO: We do not need to track constants separately, they can be tracked /// inside `eq_groups` as `Literal` expressions. - pub constants: Vec>, + pub constants: Vec, /// Schema associated with this object. schema: SchemaRef, } @@ -134,7 +135,7 @@ impl EquivalenceProperties { } /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[Arc] { + pub fn constants(&self) -> &[ConstExpr] { &self.constants } @@ -144,7 +145,7 @@ impl EquivalenceProperties { let mut output_ordering = self.oeq_class().output_ordering().unwrap_or_default(); // Prune out constant expressions output_ordering - .retain(|sort_expr| !physical_exprs_contains(constants, &sort_expr.expr)); + .retain(|sort_expr| !const_exprs_contains(constants, &sort_expr.expr)); (!output_ordering.is_empty()).then_some(output_ordering) } @@ -173,6 +174,12 @@ impl EquivalenceProperties { self.oeq_class.clear(); } + /// Removes constant expressions that may change across partitions. + /// This method should be used when data from different partitions are merged. + pub fn clear_per_partition_constants(&mut self) { + self.constants.retain(|item| item.across_partitions()); + } + /// Extends this `EquivalenceProperties` by adding the orderings inside the /// ordering equivalence class `other`. pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { @@ -204,13 +211,15 @@ impl EquivalenceProperties { // Discover new constants in light of new the equality: if self.is_expr_constant(left) { // Left expression is constant, add right as constant - if !physical_exprs_contains(&self.constants, right) { - self.constants.push(right.clone()); + if !const_exprs_contains(&self.constants, right) { + self.constants + .push(ConstExpr::new(right.clone()).with_across_partitions(true)); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant - if !physical_exprs_contains(&self.constants, left) { - self.constants.push(left.clone()); + if !const_exprs_contains(&self.constants, left) { + self.constants + .push(ConstExpr::new(left.clone()).with_across_partitions(true)); } } @@ -270,11 +279,29 @@ impl EquivalenceProperties { /// Track/register physical expressions with constant values. pub fn add_constants( mut self, - constants: impl IntoIterator>, + constants: impl IntoIterator, ) -> Self { - for expr in self.eq_group.normalize_exprs(constants) { - if !physical_exprs_contains(&self.constants, &expr) { - self.constants.push(expr); + let (const_exprs, across_partition_flags): ( + Vec>, + Vec, + ) = constants + .into_iter() + .map(|const_expr| { + let across_partitions = const_expr.across_partitions(); + let expr = const_expr.owned_expr(); + (expr, across_partitions) + }) + .unzip(); + for (expr, across_partitions) in self + .eq_group + .normalize_exprs(const_exprs) + .into_iter() + .zip(across_partition_flags) + { + if !const_exprs_contains(&self.constants, &expr) { + let const_expr = + ConstExpr::new(expr).with_across_partitions(across_partitions); + self.constants.push(const_expr); } } self @@ -326,7 +353,13 @@ impl EquivalenceProperties { sort_reqs: LexRequirementRef, ) -> LexRequirement { let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + let mut constant_exprs = vec![]; + constant_exprs.extend( + self.constants + .iter() + .map(|const_expr| const_expr.expr().clone()), + ); + let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); // Prune redundant sections in the requirement: collapse_lex_req( normalized_sort_reqs @@ -370,8 +403,8 @@ impl EquivalenceProperties { // From the analysis above, we know that `[a ASC]` is satisfied. Then, // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. - eq_properties = - eq_properties.add_constants(std::iter::once(normalized_req.expr)); + eq_properties = eq_properties + .add_constants(std::iter::once(ConstExpr::new(normalized_req.expr))); } true } @@ -781,24 +814,25 @@ impl EquivalenceProperties { /// # Returns /// /// Returns a `Vec>` containing the projected constants. - fn projected_constants( - &self, - mapping: &ProjectionMapping, - ) -> Vec> { + fn projected_constants(&self, mapping: &ProjectionMapping) -> Vec { // First, project existing constants. For example, assume that `a + b` // is known to be constant. If the projection were `a as a_new`, `b as b_new`, // then we would project constant `a + b` as `a_new + b_new`. let mut projected_constants = self .constants .iter() - .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .flat_map(|const_expr| { + const_expr.map(|expr| self.eq_group.project_expr(mapping, expr)) + }) .collect::>(); // Add projection expressions that are known to be constant: for (source, target) in mapping.iter() { if self.is_expr_constant(source) - && !physical_exprs_contains(&projected_constants, target) + && !const_exprs_contains(&projected_constants, target) { - projected_constants.push(target.clone()); + // Expression evaluates to single value + projected_constants + .push(ConstExpr::new(target.clone()).with_across_partitions(true)); } } projected_constants @@ -891,8 +925,8 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.add_constants(std::iter::once(expr.clone())); + eq_properties = eq_properties + .add_constants(std::iter::once(ConstExpr::new(expr.clone()))); search_indices.shift_remove(idx); } // Add new ordered section to the state. @@ -917,7 +951,11 @@ impl EquivalenceProperties { // As an example, assume that we know columns `a` and `b` are constant. // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will // return `false`. - let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let const_exprs = self + .constants + .iter() + .map(|const_expr| const_expr.expr().clone()); + let normalized_constants = self.eq_group.normalize_exprs(const_exprs); let normalized_expr = self.eq_group.normalize_expr(expr.clone()); is_constant_recurse(&normalized_constants, &normalized_expr) } @@ -1307,8 +1345,16 @@ pub fn join_equivalence_properties( on, )); - let left_oeq_class = left.oeq_class; - let mut right_oeq_class = right.oeq_class; + let EquivalenceProperties { + constants: left_constants, + oeq_class: left_oeq_class, + .. + } = left; + let EquivalenceProperties { + constants: right_constants, + oeq_class: mut right_oeq_class, + .. + } = right; match maintains_input_order { [true, false] => { // In this special case, right side ordering can be prefixed with @@ -1361,6 +1407,15 @@ pub fn join_equivalence_properties( [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), } + match join_type { + JoinType::LeftAnti | JoinType::LeftSemi => { + result = result.add_constants(left_constants); + } + JoinType::RightAnti | JoinType::RightSemi => { + result = result.add_constants(right_constants); + } + _ => {} + } result } @@ -2088,7 +2143,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.add_constants(vec![col_h.clone()]); + eq_properties = eq_properties.add_constants(vec![ConstExpr::new(col_h.clone())]); let test_cases = vec![ // TEST CASE 1 @@ -2386,7 +2441,9 @@ mod tests { ]; for case in cases { - let mut properties = base_properties.clone().add_constants(case.constants); + let mut properties = base_properties + .clone() + .add_constants(case.constants.into_iter().map(ConstExpr::new)); for [left, right] in &case.equal_conditions { properties.add_equal_conditions(left, right)? } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index b764e81a95d13..06c73636773eb 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -46,7 +46,7 @@ pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use datafusion_physical_expr_common::aggregate::{ AggregateExpr, AggregatePhysicalExpressions, }; -pub use equivalence::EquivalenceProperties; +pub use equivalence::{ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index ce67cba2cd0e0..93f449f2d39b8 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -65,7 +65,7 @@ impl CoalescePartitionsExec { // Coalescing partitions loses existing orderings: let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_orderings(); - + eq_properties.clear_per_partition_constants(); PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 6153dbacfbff0..c141958c11718 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -44,7 +44,7 @@ use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AnalysisContext, ExprBoundaries, PhysicalExpr, + analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, }; use futures::stream::{Stream, StreamExt}; @@ -162,7 +162,7 @@ impl FilterExec { fn extend_constants( input: &Arc, predicate: &Arc, - ) -> Vec> { + ) -> Vec { let mut res_constants = Vec::new(); let input_eqs = input.equivalence_properties(); @@ -170,10 +170,17 @@ impl FilterExec { for conjunction in conjunctions { if let Some(binary) = conjunction.as_any().downcast_ref::() { if binary.op() == &Operator::Eq { + // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { - res_constants.push(binary.right().clone()) + res_constants.push( + ConstExpr::new(binary.right().clone()) + .with_across_partitions(true), + ) } else if input_eqs.is_expr_constant(binary.right()) { - res_constants.push(binary.left().clone()) + res_constants.push( + ConstExpr::new(binary.left().clone()) + .with_across_partitions(true), + ) } } } @@ -199,7 +206,10 @@ impl FilterExec { let constants = collect_columns(predicate) .into_iter() .filter(|column| stats.column_statistics[column.index()].is_singleton()) - .map(|column| Arc::new(column) as _); + .map(|column| { + let expr = Arc::new(column) as _; + ConstExpr::new(expr).with_across_partitions(true) + }); // this is for statistics eq_properties = eq_properties.add_constants(constants); // this is for logical constant (for example: a = '1', then a could be marked as a constant) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 65f7d5070a5d5..d9e16c98eee89 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -701,6 +701,11 @@ impl RepartitionExec { if !Self::maintains_input_order_helper(input, preserve_order)[0] { eq_properties.clear_orderings(); } + // When there are more than one input partitions, they will be fused at the output. + // Therefore, remove per partition constants. + if input.output_partitioning().partition_count() > 1 { + eq_properties.clear_per_partition_constants(); + } eq_properties } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 8a349bd22abf8..e364aca3791c6 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -80,7 +80,7 @@ pub struct SortPreservingMergeExec { impl SortPreservingMergeExec { /// Create a new sort execution plan pub fn new(expr: Vec, input: Arc) -> Self { - let cache = Self::compute_properties(&input); + let cache = Self::compute_properties(&input, expr.clone()); Self { input, expr, @@ -111,11 +111,17 @@ impl SortPreservingMergeExec { } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(input: &Arc) -> PlanProperties { + fn compute_properties( + input: &Arc, + ordering: Vec, + ) -> PlanProperties { + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.clear_per_partition_constants(); + eq_properties.add_new_orderings(vec![ordering]); PlanProperties::new( - input.equivalence_properties().clone(), // Equivalence Properties - Partitioning::UnknownPartitioning(1), // Output Partitioning - input.execution_mode(), // Execution Mode + eq_properties, // Equivalence Properties + Partitioning::UnknownPartitioning(1), // Output Partitioning + input.execution_mode(), // Execution Mode ) } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index dc7d270bae257..3f88eb4c3732b 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -41,7 +41,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use futures::Stream; use itertools::Itertools; @@ -118,41 +118,11 @@ impl UnionExec { schema: SchemaRef, ) -> PlanProperties { // Calculate equivalence properties: - // TODO: In some cases, we should be able to preserve some equivalence - // classes and constants. Add support for such cases. let children_eqs = inputs .iter() .map(|child| child.equivalence_properties()) .collect::>(); - let mut eq_properties = EquivalenceProperties::new(schema); - // Use the ordering equivalence class of the first child as the seed: - let mut meets = children_eqs[0] - .oeq_class() - .iter() - .map(|item| item.to_vec()) - .collect::>(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - // Compute meet orderings of the current meets and the new ordering - // equivalence class. - let mut idx = 0; - while idx < meets.len() { - // Find all the meets of `current_meet` with this child's orderings: - let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { - child_eqs.get_meet_ordering(ordering, &meets[idx]) - }); - // Use the longest of these meets as others are redundant: - if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { - meets[idx] = next_meet; - idx += 1; - } else { - meets.swap_remove(idx); - } - } - } - // We know have all the valid orderings after union, remove redundant - // entries (implicitly) and return: - eq_properties.add_new_orderings(meets); + let eq_properties = calculate_union_eq_properties(&children_eqs, schema); // Calculate output partitioning; i.e. sum output partitions of the inputs. let num_partitions = inputs @@ -167,6 +137,68 @@ impl UnionExec { PlanProperties::new(eq_properties, output_partitioning, mode) } } +/// Calculate `EquivalenceProperties` for `UnionExec` from the `EquivalenceProperties` +/// of its children. +fn calculate_union_eq_properties( + children_eqs: &[&EquivalenceProperties], + schema: SchemaRef, +) -> EquivalenceProperties { + // Calculate equivalence properties: + // TODO: In some cases, we should be able to preserve some equivalence + // classes and constants. Add support for such cases. + let mut eq_properties = EquivalenceProperties::new(schema); + // Use the ordering equivalence class of the first child as the seed: + let mut meets = children_eqs[0] + .oeq_class() + .iter() + .map(|item| item.to_vec()) + .collect::>(); + // Iterate over all the children: + for child_eqs in &children_eqs[1..] { + // Compute meet orderings of the current meets and the new ordering + // equivalence class. + let mut idx = 0; + while idx < meets.len() { + // Find all the meets of `current_meet` with this child's orderings: + let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { + child_eqs.get_meet_ordering(ordering, &meets[idx]) + }); + // Use the longest of these meets as others are redundant: + if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { + meets[idx] = next_meet; + idx += 1; + } else { + meets.swap_remove(idx); + } + } + } + // We know have all the valid orderings after union, remove redundant + // entries (implicitly) and return: + eq_properties.add_new_orderings(meets); + + let mut meet_constants = children_eqs[0].constants().to_vec(); + // Iterate over all the children: + for child_eqs in &children_eqs[1..] { + let constants = child_eqs.constants(); + meet_constants = meet_constants + .into_iter() + .filter_map(|meet_constant| { + for const_expr in constants { + if const_expr.expr().eq(meet_constant.expr()) { + // TODO: Check whether constant expressions evaluates the same value or not for each partition + let across_partitions = false; + return Some( + ConstExpr::new(meet_constant.owned_expr()) + .with_across_partitions(across_partitions), + ); + } + } + None + }) + .collect::>(); + } + eq_properties.add_constants(meet_constants) +} impl DisplayAs for UnionExec { fn fmt_as( diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 181c308004346..252c8d12b5194 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -41,7 +41,8 @@ use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ reverse_order_bys, window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, - AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, + AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering, + PhysicalSortRequirement, }; use itertools::Itertools; @@ -576,7 +577,10 @@ pub fn get_window_mode( options: None, })); // Treat partition by exprs as constant. During analysis of requirements are satisfied. - let partition_by_eqs = input_eqs.add_constants(partitionby_exprs.iter().cloned()); + let const_exprs = partitionby_exprs + .iter() + .map(|expr| ConstExpr::new(expr.clone())); + let partition_by_eqs = input_eqs.add_constants(const_exprs); let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); let reverse_order_by_reqs = PhysicalSortRequirement::from_sort_exprs(&reverse_order_bys(orderby_keys)); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 552ad6a2a7563..e891093c81560 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -251,7 +251,7 @@ physical_plan 02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 -05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))] +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted 06)----------UnionExec 07)------------ProjectionExec: expr=[1 as id, 2 as foo] 08)--------------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b850760b8734a..3a4e8072bbc76 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -257,7 +257,7 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] physical_plan_with_schema CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] @@ -336,7 +336,7 @@ physical_plan after OutputRequirements 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -379,7 +379,7 @@ physical_plan after OutputRequirements 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 501ae497745b0..3cbeea0f92221 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3165,6 +3165,9 @@ WITH ORDER (a ASC NULLS FIRST, b ASC, c ASC) LOCATION '../core/tests/data/window_2.csv' OPTIONS ('format.has_header' 'true'); +statement ok +set datafusion.optimizer.prefer_existing_sort = true; + # sort merge join should propagate ordering equivalence of the left side # for inner join. Hence final requirement rn1 ASC is already satisfied at # the end of SortMergeJoinExec. @@ -3188,18 +3191,16 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [rn1@5 ASC NULLS LAST] 02)--SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] -03)----SortExec: expr=[rn1@5 ASC NULLS LAST], preserve_partitioning=[true] -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -07)------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -08)--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -10)----SortExec: expr=[a@1 ASC], preserve_partitioning=[true] -11)------CoalesceBatchesExec: target_batch_size=2 -12)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -13)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -14)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@5 ASC NULLS LAST +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +07)------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +09)----CoalesceBatchesExec: target_batch_size=2 +10)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST +11)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +12)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true # sort merge join should propagate ordering equivalence of the right side # for right join. Hence final requirement rn1 ASC is already satisfied at @@ -3224,18 +3225,19 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [rn1@10 ASC NULLS LAST] 02)--SortMergeJoin: join_type=Right, on=[(a@1, a@1)] -03)----SortExec: expr=[a@1 ASC], preserve_partitioning=[true] -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -08)----SortExec: expr=[rn1@5 ASC NULLS LAST], preserve_partitioning=[true] -09)------CoalesceBatchesExec: target_batch_size=2 -10)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -11)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -12)------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -13)--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -14)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@5 ASC NULLS LAST +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +11)------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +12)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.optimizer.prefer_existing_sort = false; # SortMergeJoin should add ordering equivalences of # right table as lexicographical append to the global ordering diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index e6f3e70c1ebda..ba07b4ed0a87c 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -276,7 +276,7 @@ physical_plan 04)------AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[MAX(d.a)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] +07)------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)], ordering_mode=Sorted 08)--------------UnionExec 09)----------------ProjectionExec: expr=[1 as a, aa as b] 10)------------------PlaceholderRowExec @@ -3190,8 +3190,8 @@ SELECT a, d, rn1, rank1 FROM (SELECT a, d, # this is a negative test for asserting that ROW_NUMBER is not # added to the ordering equivalence when it contains partition by. # physical plan should contain SortExec. Since source is unbounded -# pipeline checker should raise error, when plan contains SortExec. -statement error DataFusion error: PipelineChecker +# sanity checker should raise error, when plan contains SortExec. +statement error DataFusion error: SanityCheckPlan SELECT a, d, rn1 FROM (SELECT a, d, ROW_NUMBER() OVER(PARTITION BY d ORDER BY a ASC) as rn1 FROM annotated_data_infinite2 From 5feaa4189e7a14ff32ad222ddafc1d7f4f33510c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Jul 2024 12:14:19 -0400 Subject: [PATCH 5/5] Test + workaround for SanityCheck plan --- .../src/physical_optimizer/sanity_checker.rs | 10 +++++ datafusion/sqllogictest/test_files/union.slt | 38 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs index 083b42f7400bc..46e13c74d667f 100644 --- a/datafusion/core/src/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -34,6 +34,8 @@ use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supp use datafusion_physical_plan::joins::SymmetricHashJoinExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::UnionExec; use itertools::izip; /// The SanityCheckPlan rule rejects the following query plans: @@ -125,6 +127,14 @@ pub fn check_plan_sanity( plan.required_input_ordering().iter(), plan.required_input_distribution().iter() ) { + // TEMP HACK WORKAROUND https://github.com/apache/datafusion/issues/11492 + if child.as_any().downcast_ref::().is_some() { + continue; + } + if child.as_any().downcast_ref::().is_some() { + continue; + } + let child_eq_props = child.equivalence_properties(); if let Some(child_sort_req) = child_sort_req { if !child_eq_props.ordering_satisfy_requirement(child_sort_req) { diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 7b91e97e4a3e2..c343cf4fb7a08 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -449,6 +449,9 @@ physical_plan # Clean up after the test ######## +statement ok +drop table t + statement ok drop table t1; @@ -587,3 +590,38 @@ physical_plan 09)--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] 10)----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] 11)------PlaceholderRowExec + + +### +# Test for https://github.com/apache/datafusion/issues/11492 +### + +# Input data is +# a,b,c +# 1,2,3 + +statement ok +CREATE EXTERNAL TABLE t ( + a INT, + b INT, + c INT +) +STORED AS CSV +LOCATION '../core/tests/data/example.csv' +WITH ORDER (a ASC) +OPTIONS ('format.has_header' 'true'); + +query T +SELECT (SELECT a from t ORDER BY a) UNION ALL (SELECT 'bar' as a from t) ORDER BY a; +---- +1 +bar + +query I +SELECT (SELECT a from t ORDER BY a) UNION ALL (SELECT NULL as a from t) ORDER BY a; +---- +1 +NULL + +statement ok +drop table t