diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 364c412fb02de..38f68ff46f42e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -499,9 +499,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4" +checksum = "3f83d0ebf42c6eafb8d7c52f7e5f2d3003b89c7aa4fd2b79229209459a849af8" dependencies = [ "cc", "cxxbridge-flags", @@ -511,9 +511,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199" +checksum = "07d050484b55975889284352b0ffc2ecbda25c0c55978017c132b29ba0818a86" dependencies = [ "cc", "codespan-reporting", @@ -526,15 +526,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c" +checksum = "99d2199b00553eda8012dfec8d3b1c75fce747cf27c169a270b3b99e3448ab78" [[package]] name = "cxxbridge-macro" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea" +checksum = "dcb67a6de1f602736dd7eaead0080cf3435df806c61b24b13328db128c58868f" dependencies = [ "proc-macro2", "quote", @@ -649,6 +649,7 @@ dependencies = [ "datafusion-expr", "datafusion-row", "hashbrown", + "itertools", "lazy_static", "md-5", "ordered-float 3.2.0", @@ -1120,9 +1121,9 @@ dependencies = [ [[package]] name = "iana-time-zone-haiku" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" dependencies = [ "cxx", "cxx-build", @@ -1298,9 +1299,9 @@ checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" [[package]] name = "libmimalloc-sys" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ca136052550448f55df7898c6dbe651c6b574fe38a0d9ea687a9f8088a2e2c" +checksum = "8fc093ab289b0bfda3aa1bdfab9c9542be29c7ef385cfcbe77f8c9813588eb48" dependencies = [ "cc", ] @@ -1376,9 +1377,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mimalloc" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f64ad83c969af2e732e907564deb0d0ed393cec4af80776f77dd77a1a427698" +checksum = "76ce6a4b40d3bff9eb3ce9881ca0737a85072f9f975886082640cd46a75cdb35" dependencies = [ "libmimalloc-sys", ] @@ -1712,9 +1713,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" dependencies = [ "unicode-ident", ] diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index cbb74bbf9b4ce..0944564f88928 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -147,7 +147,7 @@ fn create_runtime_env() -> Result { ObjectStoreRegistry::new_with_provider(Some(Arc::new(object_store_provider))); let rn_config = RuntimeConfig::new().with_object_store_registry(Arc::new(object_store_registry)); - return RuntimeEnv::new(rn_config); + RuntimeEnv::new(rn_config) } fn is_valid_file(dir: &str) -> 
std::result::Result<(), String> { diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 19993e751c26d..df5235328a2cc 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -68,7 +68,7 @@ fn build_gcs_object_store(url: &Url) -> Result Result Result<&str> { - url.host_str().ok_or(DataFusionError::Execution(format!( - "Not able to parse hostname from url, {}", - url.as_str() - ))) + url.host_str().ok_or_else(|| { + DataFusionError::Execution(format!( + "Not able to parse hostname from url, {}", + url.as_str() + )) + }) } #[cfg(test)] diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 7256c94ff0cfe..18499ddcabb6b 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -223,6 +223,12 @@ impl ExecutionPlan for CustomExec { None } + fn equivalence_properties( + &self, + ) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index a699a234cd6d1..188541049f6f9 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1519,4 +1519,84 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn partition_aware_union() -> Result<()> { + let left = test_table().await?.select_columns(&["c1", "c2"])?; + let right = test_table_with_name("c2") + .await? + .select_columns(&["c1", "c3"])? 
+ .with_column_renamed("c2.c1", "c2_c1")?; + + let left_rows = left.collect().await?; + let right_rows = right.collect().await?; + let join1 = + left.join(right.clone(), JoinType::Inner, &["c1"], &["c2_c1"], None)?; + let join2 = left.join(right, JoinType::Inner, &["c1"], &["c2_c1"], None)?; + + let union = join1.union(join2)?; + + let union_rows = union.collect().await?; + + assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::()); + assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::()); + assert_eq!(4016, union_rows.iter().map(|x| x.num_rows()).sum::()); + + let physical_plan = union.create_physical_plan().await?; + let default_partition_count = + SessionContext::new().copied_config().target_partitions; + assert_eq!( + physical_plan.output_partitioning().partition_count(), + default_partition_count + ); + Ok(()) + } + + #[tokio::test] + async fn non_partition_aware_union() -> Result<()> { + let left = test_table().await?.select_columns(&["c1", "c2"])?; + let right = test_table_with_name("c2") + .await? + .select_columns(&["c1", "c2"])? + .with_column_renamed("c2.c1", "c2_c1")? 
+ .with_column_renamed("c2.c2", "c2_c2")?; + + let left_rows = left.collect().await?; + let right_rows = right.collect().await?; + let join1 = left.join( + right.clone(), + JoinType::Inner, + &["c1", "c2"], + &["c2_c1", "c2_c2"], + None, + )?; + + // join key ordering is different + let join2 = left.join( + right, + JoinType::Inner, + &["c2", "c1"], + &["c2_c2", "c2_c1"], + None, + )?; + + let union = join1.union(join2)?; + + let union_rows = union.collect().await?; + + assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::()); + assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::()); + assert_eq!(916, union_rows.iter().map(|x| x.num_rows()).sum::()); + + let physical_plan = union.create_physical_plan().await?; + let default_partition_count = + SessionContext::new().copied_config().target_partitions; + + // the union's output partitioning count should be the combination of all output partitions count + assert_eq!( + physical_plan.output_partitioning().partition_count(), + default_partition_count * 2 + ); + Ok(()) + } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c66199e073b71..b8cf90a6c00d0 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -73,7 +73,6 @@ use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule}; use datafusion_sql::{ResolvedTableReference, TableReference}; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; -use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; use crate::config::{ @@ -82,6 +81,7 @@ use crate::config::{ }; use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry}; +use crate::physical_optimizer::enforcement::BasicEnforcement; use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet}; use 
crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udaf::AggregateUDF; @@ -1145,6 +1145,8 @@ pub struct SessionConfig { pub parquet_pruning: bool, /// Should DataFusion collect statistics after listing files pub collect_statistics: bool, + /// Should DataFusion optimizer run a top down process to reorder the join keys + pub top_down_join_key_reordering: bool, /// Configuration options pub config_options: Arc>, /// Opaque extensions. @@ -1164,6 +1166,7 @@ impl Default for SessionConfig { repartition_windows: true, parquet_pruning: true, collect_statistics: false, + top_down_join_key_reordering: true, config_options: Arc::new(RwLock::new(ConfigOptions::new())), // Assume no extensions by default. extensions: HashMap::with_capacity_and_hasher( @@ -1479,6 +1482,7 @@ impl SessionState { Arc::new(AggregateStatistics::new()), Arc::new(HashBuildProbeOrder::new()), ]; + physical_optimizers.push(Arc::new(BasicEnforcement::new())); if config .config_options .read() @@ -1496,7 +1500,8 @@ impl SessionState { ))); } physical_optimizers.push(Arc::new(Repartition::new())); - physical_optimizers.push(Arc::new(AddCoalescePartitionsExec::new())); + physical_optimizers.push(Arc::new(BasicEnforcement::new())); + // physical_optimizers.push(Arc::new(AddCoalescePartitionsExec::new())); SessionState { session_id, diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 9f7e22dbb0857..5557f5070db00 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -23,9 +23,10 @@ use crate::{ physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, with_new_children_if_necessary, + repartition::RepartitionExec, TreeNodeRewritable, }, }; + use std::sync::Arc; /// Optimizer rule that introduces 
CoalesceBatchesExec to avoid overhead with small batches that @@ -42,40 +43,32 @@ impl CoalesceBatches { Self { target_batch_size } } } + impl PhysicalOptimizerRule for CoalesceBatches { fn optimize( &self, plan: Arc, - config: &crate::execution::context::SessionConfig, + _config: &crate::execution::context::SessionConfig, ) -> Result> { - if plan.children().is_empty() { - // leaf node, children cannot be replaced - Ok(plan.clone()) - } else { - // recurse down first - let children = plan - .children() - .iter() - .map(|child| self.optimize(child.clone(), config)) - .collect::>>()?; - let plan = with_new_children_if_necessary(plan, children)?; + let target_batch_size = self.target_batch_size; + plan.transform_up(&|plan| { + let plan_any = plan.as_any(); // The goal here is to detect operators that could produce small batches and only // wrap those ones with a CoalesceBatchesExec operator. An alternate approach here // would be to build the coalescing logic directly into the operators // See https://github.com/apache/arrow-datafusion/issues/139 - let plan_any = plan.as_any(); let wrap_in_coalesce = plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some(); - Ok(if wrap_in_coalesce { - Arc::new(CoalesceBatchesExec::new( + if wrap_in_coalesce { + Some(Arc::new(CoalesceBatchesExec::new( plan.clone(), - self.target_batch_size, - )) + target_batch_size, + ))) } else { - plan.clone() - }) - } + None + } + }) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs new file mode 100644 index 0000000000000..a156e9459912e --- /dev/null +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -0,0 +1,1636 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Enforcement optimizer rules are used to make sure the plan's Distribution and Ordering +//! requirements are met by inserting necessary [[RepartitionExec]] and [[SortExec]]. +//! +use crate::error::Result; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; +use crate::physical_plan::{Partitioning, TreeNodeRewritable}; +use crate::prelude::SessionConfig; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::NoOp; +use datafusion_physical_expr::{ + expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, + normalize_sort_expr_with_equivalence_properties, PhysicalExpr, PhysicalSortExpr, +}; +use std::collections::HashMap; +use std::sync::Arc; + +/// BasicEnforcement rule, it ensures the Distribution and Ordering requirements are met +/// in the strictest way. 
+/// It might add additional [[RepartitionExec]] to the plan tree
+/// and give a non-optimal plan, but it can avoid the possible data skew in joins.
+///
+/// For example, for a HashJoin with keys (a, b, c), the required Distribution(a, b, c) can be satisfied by
+/// several alternative partitioning ways: [(a, b, c), (a, b), (a, c), (b, c), (a), (b), (c), ( )].
+///
+/// This rule only chooses the exact match and satisfies the Distribution(a, b, c) by a HashPartition(a, b, c).
+#[derive(Default)]
+pub struct BasicEnforcement {}
+
+impl BasicEnforcement {
+    /// Creates a new `BasicEnforcement` rule. The rule itself carries no state;
+    /// everything it needs is read from the `SessionConfig` passed to `optimize`.
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl PhysicalOptimizerRule for BasicEnforcement {
+    /// Rewrites `plan` so that the input distribution and ordering required by
+    /// every operator are satisfied.
+    ///
+    /// When `config.top_down_join_key_reordering` is set, join keys are first
+    /// adjusted in a top-down pass; otherwise each join's keys are reordered
+    /// against its inputs while the bottom-up enforcement pass runs.
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        config: &SessionConfig,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let target_partitions = config.target_partitions;
+        let top_down_join_key_reordering = config.top_down_join_key_reordering;
+        let new_plan = if top_down_join_key_reordering {
+            // Run a top-down process to adjust input key ordering recursively.
+            // The root has no parent, hence the empty requirement list.
+            adjust_input_keys_down_recursively(plan, vec![])?
+        } else {
+            plan
+        };
+        // Distribution and Ordering enforcement need to be applied bottom-up.
+        new_plan.transform_up(&|plan| {
+            let adjusted = if top_down_join_key_reordering {
+                plan
+            } else {
+                reorder_join_keys_to_inputs(plan)
+            };
+            // Always `Some`: every node is replaced by its enforced version.
+            Some(ensure_distribution_and_ordering(
+                adjusted,
+                target_partitions,
+            ))
+        })
+    }
+
+    fn name(&self) -> &str {
+        "BasicEnforcement"
+    }
+}
+
+/// When the physical planner creates the Joins, the ordering of join keys is from the original query.
+/// That might not match with the output partitioning of the join node's children.
+/// This method runs a top-down process and tries to adjust the output partitioning of the children
+/// if children themselves are joins or aggregations.
+///
+/// `parent_required` lists the key expressions the parent operator would like
+/// this node's output to be keyed on, in order; an empty list means the parent
+/// imposes no requirement. Nodes that cannot honour the request are returned
+/// unchanged.
+///
+/// # Errors
+/// Returns an error if rebuilding a join, aggregate or projection fails, or if
+/// a group column cannot be found in the rebuilt aggregate's schema.
+fn adjust_input_keys_down_recursively(
+    plan: Arc<dyn ExecutionPlan>,
+    parent_required: Vec<Arc<dyn PhysicalExpr>>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let plan_any = plan.as_any();
+    if let Some(HashJoinExec {
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        mode,
+        null_equals_null,
+        ..
+    }) = plan_any.downcast_ref::<HashJoinExec>()
+    {
+        match mode {
+            PartitionMode::Partitioned => {
+                let join_key_pairs = extract_join_keys(on);
+                // Prefer the key order requested by the parent; fall back to the
+                // join's own order when the request cannot be matched.
+                let (keys, new_join_on) = match try_reorder(
+                    join_key_pairs.clone(),
+                    parent_required,
+                    &plan.equivalence_properties(),
+                ) {
+                    Some((reordered, positions)) if !positions.is_empty() => {
+                        let new_on = new_join_conditions(
+                            &reordered.left_keys,
+                            &reordered.right_keys,
+                        );
+                        (reordered, new_on)
+                    }
+                    // Already in the requested order: keep the conditions as they are.
+                    Some((unchanged, _)) => (unchanged, on.clone()),
+                    None => (join_key_pairs, on.clone()),
+                };
+                let new_left =
+                    adjust_input_keys_down_recursively(left.clone(), keys.left_keys)?;
+                let new_right =
+                    adjust_input_keys_down_recursively(right.clone(), keys.right_keys)?;
+                Ok(Arc::new(HashJoinExec::try_new(
+                    new_left,
+                    new_right,
+                    new_join_on,
+                    filter.clone(),
+                    join_type,
+                    PartitionMode::Partitioned,
+                    null_equals_null,
+                )?))
+            }
+            PartitionMode::CollectLeft => {
+                // Only the right side receives the parent's requirement here.
+                let new_right =
+                    adjust_input_keys_down_recursively(right.clone(), parent_required)?;
+                Ok(Arc::new(HashJoinExec::try_new(
+                    left.clone(),
+                    new_right,
+                    on.clone(),
+                    filter.clone(),
+                    join_type,
+                    PartitionMode::CollectLeft,
+                    null_equals_null,
+                )?))
+            }
+        }
+    } else if let Some(SortMergeJoinExec {
+        left,
+        right,
+        on,
+        join_type,
+        sort_options,
+        null_equals_null,
+        ..
+    }) = plan_any.downcast_ref::<SortMergeJoinExec>()
+    {
+        let join_key_pairs = extract_join_keys(on);
+        let (keys, new_join_on, new_sort_options) = match try_reorder(
+            join_key_pairs.clone(),
+            parent_required,
+            &plan.equivalence_properties(),
+        ) {
+            Some((reordered, positions)) if !positions.is_empty() => {
+                let new_on =
+                    new_join_conditions(&reordered.left_keys, &reordered.right_keys);
+                // Sort options travel with their key, so permute them the same way.
+                let new_options = positions
+                    .iter()
+                    .map(|pos| sort_options[*pos])
+                    .collect::<Vec<_>>();
+                (reordered, new_on, new_options)
+            }
+            Some((unchanged, _)) => (unchanged, on.clone(), sort_options.clone()),
+            None => (join_key_pairs, on.clone(), sort_options.clone()),
+        };
+        let new_left = adjust_input_keys_down_recursively(left.clone(), keys.left_keys)?;
+        let new_right =
+            adjust_input_keys_down_recursively(right.clone(), keys.right_keys)?;
+        Ok(Arc::new(SortMergeJoinExec::try_new(
+            new_left,
+            new_right,
+            new_join_on,
+            *join_type,
+            new_sort_options,
+            *null_equals_null,
+        )?))
+    } else if let Some(AggregateExec {
+        mode,
+        group_by,
+        aggr_expr,
+        input,
+        input_schema,
+        ..
+    }) = plan_any.downcast_ref::<AggregateExec>()
+    {
+        if parent_required.is_empty()
+            || !matches!(
+                mode,
+                AggregateMode::FinalPartitioned | AggregateMode::Partial
+            )
+        {
+            return Ok(plan);
+        }
+        // Group columns come first in the aggregate's output, in group-by order.
+        let out_put_columns = group_by
+            .expr()
+            .iter()
+            .enumerate()
+            .map(|(index, (_col, name))| Column::new(name, index))
+            .collect::<Vec<_>>();
+        let out_put_exprs = out_put_columns
+            .iter()
+            .map(|c| Arc::new(c.clone()) as Arc<dyn PhysicalExpr>)
+            .collect::<Vec<_>>();
+
+        // Check whether the requirements can be satisfied by the Aggregation:
+        // nothing to do when the arity differs, the order already matches, or
+        // the group-by carries null expressions.
+        if parent_required.len() != out_put_exprs.len()
+            || expr_list_eq_strict_order(&out_put_exprs, &parent_required)
+            || !group_by.null_expr().is_empty()
+        {
+            return Ok(plan);
+        }
+        let positions = match expected_expr_positions(&out_put_exprs, &parent_required)
+        {
+            Some(positions) => positions,
+            None => return Ok(plan),
+        };
+        let new_group_exprs = positions
+            .into_iter()
+            .map(|idx| group_by.expr()[idx].clone())
+            .collect::<Vec<_>>();
+        let new_group_by = PhysicalGroupBy::new_single(new_group_exprs);
+        match mode {
+            AggregateMode::FinalPartitioned => {
+                let new_input =
+                    adjust_input_keys_down_recursively(input.clone(), parent_required)?;
+                let new_agg = Arc::new(AggregateExec::try_new(
+                    AggregateMode::FinalPartitioned,
+                    new_group_by,
+                    aggr_expr.clone(),
+                    new_input,
+                    input_schema.clone(),
+                )?);
+                let agg_schema = new_agg.schema();
+
+                // Need to create a new projection to change the expr ordering back:
+                // first the group columns in their original order ...
+                let mut proj_exprs = out_put_columns
+                    .iter()
+                    .map(|col| -> Result<(Arc<dyn PhysicalExpr>, String)> {
+                        let idx = agg_schema.index_of(col.name())?;
+                        Ok((
+                            Arc::new(Column::new(col.name(), idx))
+                                as Arc<dyn PhysicalExpr>,
+                            col.name().to_owned(),
+                        ))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                // ... then every remaining output column at its existing position.
+                proj_exprs.extend(
+                    agg_schema
+                        .fields()
+                        .iter()
+                        .enumerate()
+                        .skip(out_put_columns.len())
+                        .map(|(idx, field)| {
+                            (
+                                Arc::new(Column::new(field.name().as_str(), idx))
+                                    as Arc<dyn PhysicalExpr>,
+                                field.name().clone(),
+                            )
+                        }),
+                );
+                // TODO merge adjacent Projections if there are
+                Ok(Arc::new(ProjectionExec::try_new(proj_exprs, new_agg)?))
+            }
+            AggregateMode::Partial => Ok(Arc::new(AggregateExec::try_new(
+                AggregateMode::Partial,
+                new_group_by,
+                aggr_expr.clone(),
+                input.clone(),
+                input_schema.clone(),
+            )?)),
+            _ => Ok(plan),
+        }
+    } else if let Some(ProjectionExec { expr, .. }) =
+        plan_any.downcast_ref::<ProjectionExec>()
+    {
+        // For Projection, we need to transform the columns to the columns before the Projection
+        // And then to push down the requirements
+        let column_mapping = expr
+            .iter()
+            .filter_map(|(expression, name)| {
+                expression
+                    .as_any()
+                    .downcast_ref::<Column>()
+                    .map(|column| (name.clone(), column.clone()))
+            })
+            .collect::<HashMap<_, _>>();
+        let new_required: Vec<Arc<dyn PhysicalExpr>> = parent_required
+            .iter()
+            .filter_map(|r| {
+                r.as_any()
+                    .downcast_ref::<Column>()
+                    .and_then(|column| column_mapping.get(column.name()))
+            })
+            .map(|e| Arc::new(e.clone()) as Arc<dyn PhysicalExpr>)
+            .collect::<Vec<_>>();
+        // Push down only if every required expression maps to an input column.
+        if new_required.len() == parent_required.len() {
+            plan.map_children(|plan| {
+                adjust_input_keys_down_recursively(plan, new_required.clone())
+            })
+        } else {
+            Ok(plan)
+        }
+    } else if plan_any.is::<WindowAggExec>() || parent_required.is_empty() {
+        // TODO: window aggregates are not adjusted yet.
+        Ok(plan)
+    } else {
+        plan.map_children(|plan| {
+            adjust_input_keys_down_recursively(plan, parent_required.clone())
+        })
+    }
+}
+
+/// When the physical planner creates the Joins, the ordering of join keys is from the original query.
+/// That might not match with the output partitioning of the join node's children
+/// This method will try to change the ordering of the join keys to match with the
+/// partitioning of the join nodes' children.
+/// If it can not match with both sides, it will try to match with one, either left side or right side.
+///
+/// Plans that are not partitioned hash joins or sort-merge joins, and joins
+/// whose keys already match or cannot be matched, are returned unchanged.
+///
+/// # Panics
+/// Panics if the join cannot be rebuilt with the permuted keys.
+fn reorder_join_keys_to_inputs(
+    plan: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    let plan_any = plan.as_any();
+    if let Some(HashJoinExec {
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        mode,
+        null_equals_null,
+        ..
+    }) = plan_any.downcast_ref::<HashJoinExec>()
+    {
+        // Only partitioned hash joins are reordered here.
+        if !matches!(mode, PartitionMode::Partitioned) {
+            return plan;
+        }
+        match reorder_current_join_keys(
+            extract_join_keys(on),
+            Some(left.output_partitioning()),
+            Some(right.output_partitioning()),
+            &plan.equivalence_properties(),
+        ) {
+            // An empty position list means the keys already line up.
+            Some((
+                JoinKeyPairs {
+                    left_keys,
+                    right_keys,
+                },
+                new_positions,
+            )) if !new_positions.is_empty() => {
+                let new_join_on = new_join_conditions(&left_keys, &right_keys);
+                Arc::new(
+                    HashJoinExec::try_new(
+                        left.clone(),
+                        right.clone(),
+                        new_join_on,
+                        filter.clone(),
+                        join_type,
+                        PartitionMode::Partitioned,
+                        null_equals_null,
+                    )
+                    .expect("permuting the keys of a valid hash join keeps it valid"),
+                )
+            }
+            _ => plan,
+        }
+    } else if let Some(SortMergeJoinExec {
+        left,
+        right,
+        on,
+        join_type,
+        sort_options,
+        null_equals_null,
+        ..
+    }) = plan_any.downcast_ref::<SortMergeJoinExec>()
+    {
+        match reorder_current_join_keys(
+            extract_join_keys(on),
+            Some(left.output_partitioning()),
+            Some(right.output_partitioning()),
+            &plan.equivalence_properties(),
+        ) {
+            Some((
+                JoinKeyPairs {
+                    left_keys,
+                    right_keys,
+                },
+                new_positions,
+            )) if !new_positions.is_empty() => {
+                let new_join_on = new_join_conditions(&left_keys, &right_keys);
+                // Sort options travel with their key, so permute them the same way.
+                let new_sort_options = new_positions
+                    .iter()
+                    .map(|pos| sort_options[*pos])
+                    .collect::<Vec<_>>();
+                Arc::new(
+                    SortMergeJoinExec::try_new(
+                        left.clone(),
+                        right.clone(),
+                        new_join_on,
+                        *join_type,
+                        new_sort_options,
+                        *null_equals_null,
+                    )
+                    .expect(
+                        "permuting the keys of a valid sort-merge join keeps it valid",
+                    ),
+                )
+            }
+            _ => plan,
+        }
+    } else {
+        plan
+    }
+}
+
+/// Reorder the current join keys ordering based on either left partition or right partition.
+///
+/// The left side's hash partitioning is tried first; the right side's is tried
+/// only if the left one is absent, is not a hash partitioning, or cannot be
+/// matched. Returns `None` when neither side yields a usable ordering.
+fn reorder_current_join_keys(
+    join_keys: JoinKeyPairs,
+    left_partition: Option<Partitioning>,
+    right_partition: Option<Partitioning>,
+    equivalence_properties: &[Vec<Column>],
+) -> Option<(JoinKeyPairs, Vec<usize>)> {
+    // Attempts a reorder against one side; only hash partitioning carries keys.
+    let try_side = |partition: Option<Partitioning>| match partition {
+        Some(Partitioning::Hash(exprs, _)) => {
+            try_reorder(join_keys.clone(), exprs, equivalence_properties)
+        }
+        _ => None,
+    };
+    try_side(left_partition).or_else(|| try_side(right_partition))
+}
+
+/// Tries to permute `join_keys` so that the left keys appear in the order given
+/// by `expected`.
+///
+/// Returns the (possibly permuted) key pairs together with the applied
+/// positions; an empty position list means the keys already match. Returns
+/// `None` if the lengths differ or no permutation matches, even after
+/// normalizing both sides with `equivalence_properties`.
+fn try_reorder(
+    join_keys: JoinKeyPairs,
+    expected: Vec<Arc<dyn PhysicalExpr>>,
+    equivalence_properties: &[Vec<Column>],
+) -> Option<(JoinKeyPairs, Vec<usize>)> {
+    if join_keys.left_keys.len() != expected.len() {
+        return None;
+    }
+    if expr_list_eq_strict_order(&expected, &join_keys.left_keys) {
+        return Some((join_keys, vec![]));
+    }
+    // Direct match: the left keys are a permutation of the expected expressions.
+    if let Some(positions) = expected_expr_positions(&join_keys.left_keys, &expected) {
+        let new_right_keys = positions
+            .iter()
+            .map(|pos| join_keys.right_keys[*pos].clone())
+            .collect::<Vec<_>>();
+        return Some((
+            JoinKeyPairs {
+                left_keys: expected,
+                right_keys: new_right_keys,
+            },
+            positions,
+        ));
+    }
+    if equivalence_properties.is_empty() {
+        return None;
+    }
+
+    // Retry after normalizing both lists with the equivalence properties.
+    let normalize = |exprs: &[Arc<dyn PhysicalExpr>]| {
+        exprs
+            .iter()
+            .map(|e| {
+                normalize_expr_with_equivalence_properties(
+                    e.clone(),
+                    equivalence_properties,
+                )
+            })
+            .collect::<Vec<_>>()
+    };
+    let normalized_expected = normalize(&expected);
+    let normalized_left_keys = normalize(&join_keys.left_keys);
+    if expr_list_eq_strict_order(&normalized_expected, &normalized_left_keys) {
+        return Some((join_keys, vec![]));
+    }
+    let positions =
+        expected_expr_positions(&normalized_left_keys, &normalized_expected)?;
+    // Keep the join's own (un-normalized) expressions, only permuted.
+    let (new_left_keys, new_right_keys): (Vec<_>, Vec<_>) = positions
+        .iter()
+        .map(|pos| {
+            (
+                join_keys.left_keys[*pos].clone(),
+                join_keys.right_keys[*pos].clone(),
+            )
+        })
+        .unzip();
+    Some((
+        JoinKeyPairs {
+            left_keys: new_left_keys,
+            right_keys: new_right_keys,
+        },
+        positions,
+    ))
+}
+
+/// Return the expected expressions positions.
+/// For example, the current expressions are ['c', 'a', 'a', 'b'], the expected expressions are ['b', 'c', 'a', 'a'],
+///
+/// This method will return a Vec [3, 0, 1, 2]
+///
+/// Returns `None` if some expected expression has no unused match in `current`.
+fn expected_expr_positions(
+    current: &[Arc<dyn PhysicalExpr>],
+    expected: &[Arc<dyn PhysicalExpr>],
+) -> Option<Vec<usize>> {
+    let mut current = current.to_vec();
+    expected
+        .iter()
+        .map(|expr| {
+            // Find the position of the expected expr in the current expressions
+            let position = current.iter().position(|e| e.eq(expr))?;
+            // Blank out the matched slot so duplicates map to distinct positions.
+            current[position] = Arc::new(NoOp::new());
+            Some(position)
+        })
+        .collect()
+}
+
+/// Splits the join conditions into separate left-key and right-key lists,
+/// preserving their order.
+fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs {
+    let (left_keys, right_keys) = on
+        .iter()
+        .map(|(l, r)| {
+            (
+                Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+                Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+            )
+        })
+        .unzip();
+    JoinKeyPairs {
+        left_keys,
+        right_keys,
+    }
+}
+
+/// Zips the key lists back into join conditions.
+///
+/// # Panics
+/// Panics if any key is not a `Column`.
+fn new_join_conditions(
+    new_left_keys: &[Arc<dyn PhysicalExpr>],
+    new_right_keys: &[Arc<dyn PhysicalExpr>],
+) -> Vec<(Column, Column)> {
+    new_left_keys
+        .iter()
+        .zip(new_right_keys.iter())
+        .map(|(l_key, r_key)| {
+            (
+                l_key.as_any().downcast_ref::<Column>().unwrap().clone(),
+                r_key.as_any().downcast_ref::<Column>().unwrap().clone(),
+            )
+        })
+        .collect::<Vec<_>>()
+}
+
+/// Wraps the children of `plan` so that each one meets the distribution and
+/// ordering `plan` requires of it. Leaf nodes are returned unchanged.
+///
+/// # Panics
+/// Panics if the number of children differs from the number of declared
+/// requirements, or if a repartition, sort or the rebuilt plan cannot be
+/// constructed.
+fn ensure_distribution_and_ordering(
+    plan: Arc<dyn ExecutionPlan>,
+    target_partitions: usize,
+) -> Arc<dyn ExecutionPlan> {
+    if plan.children().is_empty() {
+        return plan;
+    }
+    let required_input_distributions = plan.required_input_distribution();
+    let required_input_orderings = plan.required_input_ordering();
+    let children: Vec<Arc<dyn ExecutionPlan>> = plan.children();
+    assert_eq!(children.len(), required_input_distributions.len());
+    assert_eq!(children.len(), required_input_orderings.len());
+
+    // Add RepartitionExec to guarantee output partitioning
+    let children = children
+        .into_iter()
+        .zip(required_input_distributions.into_iter())
+        .map(|(child, required)| {
+            if child
+                .output_partitioning()
+                .satisfy(required.clone(), || child.equivalence_properties())
+            {
+                child
+            } else {
+                let new_child: Arc<dyn ExecutionPlan> = match required {
+                    // Collapsing to one partition is a coalesce, not a repartition.
+                    Distribution::SinglePartition
+                        if child.output_partitioning().partition_count() > 1 =>
+                    {
+                        Arc::new(CoalescePartitionsExec::new(child.clone()))
+                    }
+                    _ => {
+                        let partition = required.create_partitioning(target_partitions);
+                        Arc::new(RepartitionExec::try_new(child, partition).unwrap())
+                    }
+                };
+                new_child
+            }
+        });
+
+    // Add SortExec to guarantee output ordering
+    let new_children: Vec<Arc<dyn ExecutionPlan>> = children
+        .zip(required_input_orderings.into_iter())
+        .map(|(child, required)| {
+            if ordering_satisfy(child.output_ordering(), required, || {
+                child.equivalence_properties()
+            }) {
+                child
+            } else {
+                // `ordering_satisfy` returns true for `None`, so `required` is `Some` here.
+                let sort_expr = required.unwrap().to_vec();
+                if child.output_partitioning().partition_count() > 1 {
+                    Arc::new(SortExec::new_with_partitioning(
+                        sort_expr, child, true, None,
+                    ))
+                } else {
+                    Arc::new(SortExec::try_new(sort_expr, child, None).unwrap())
+                }
+            }
+        })
+        .collect::<Vec<_>>();
+
+    with_new_children_if_necessary(plan, new_children).unwrap()
+}
+
+/// DynamicEnforcement rule
+///
+/// Placeholder: no optimizer behaviour is implemented for this rule yet.
+#[derive(Default)]
+pub struct DynamicEnforcement {}
+
+// TODO
+impl DynamicEnforcement {
+    /// Creates a new, stateless `DynamicEnforcement` rule.
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs.
+fn ordering_satisfy Vec>>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + if required.len() > provided.len() { + false + } else { + let fast_match = required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)); + + if !fast_match { + let eq_properties = equal_properties(); + if !eq_properties.is_empty() { + let normalized_required_exprs = required + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + &eq_properties, + ) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + &eq_properties, + ) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + fast_match + } + } else { + fast_match + } + } + } + } +} + +#[derive(Debug, Clone)] +struct JoinKeyPairs { + left_keys: Vec>, + right_keys: Vec>, +} + +#[cfg(test)] +mod tests { + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_expr::logical_plan::JoinType; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::{expressions, PhysicalExpr}; + + use super::*; + use crate::config::ConfigOptions; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::object_store::ObjectStoreUrl; + use crate::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use crate::physical_plan::expressions::col; + use crate::physical_plan::file_format::{FileScanConfig, ParquetExec}; + use crate::physical_plan::joins::{ + utils::JoinOn, HashJoinExec, PartitionMode, SortMergeJoinExec, + }; + use crate::physical_plan::projection::ProjectionExec; + use 
crate::physical_plan::{displayable, Statistics}; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])) + } + + fn parquet_exec() -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::default(), + projection: None, + limit: None, + table_partition_cols: vec![], + config_options: ConfigOptions::new().into_shareable(), + }, + None, + None, + )) + } + + fn projection_exec_with_alias( + input: Arc, + alias_pairs: Vec<(String, String)>, + ) -> Arc { + let mut exprs = vec![]; + for (column, alias) in alias_pairs.iter() { + exprs.push((col(column, &input.schema()).unwrap(), alias.to_string())); + } + Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) + } + + fn aggregate_exec_with_alias( + input: Arc, + alias_pairs: Vec<(String, String)>, + ) -> Arc { + let schema = schema(); + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for (column, alias) in alias_pairs.iter() { + group_by_expr + .push((col(column, &input.schema()).unwrap(), alias.to_string())); + } + let group_by = PhysicalGroupBy::new_single(group_by_expr.clone()); + + let final_group_by_expr = group_by_expr + .iter() + .enumerate() + .map(|(index, (_col, name))| { + ( + Arc::new(expressions::Column::new(name, index)) + as Arc, + name.clone(), + ) + }) + .collect::>(); + let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + + Arc::new( + AggregateExec::try_new( + AggregateMode::FinalPartitioned, + final_grouping, + vec![], + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![], + input, + schema.clone(), + ) + .unwrap(), + ), + schema, + ) + 
.unwrap(), + ) + } + + fn hash_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, + ) -> Arc { + Arc::new( + HashJoinExec::try_new( + left, + right, + join_on.clone(), + None, + join_type, + PartitionMode::Partitioned, + &false, + ) + .unwrap(), + ) + } + + fn sort_merge_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, + ) -> Arc { + Arc::new( + SortMergeJoinExec::try_new( + left, + right, + join_on.clone(), + *join_type, + vec![SortOptions::default(); join_on.len()], + false, + ) + .unwrap(), + ) + } + + fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() + } + + /// Runs the repartition optimizer and asserts the plan against the expected + macro_rules! assert_optimized { + ($EXPECTED_LINES: expr, $PLAN: expr) => { + let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); + + // run optimizer + let optimizer = BasicEnforcement {}; + let optimized = optimizer + .optimize($PLAN, &SessionConfig::new().with_target_partitions(10))?; + + // Now format correctly + let plan = displayable(optimized.as_ref()).indent().to_string(); + let actual_lines = trim_plan_display(&plan); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; + } + + #[test] + fn multi_hash_joins() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ("d".to_string(), "d1".to_string()), + ("e".to_string(), "e1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::Semi, + JoinType::Anti, + ]; + + // Join on (a == b1) + let join_on = vec![( + 
Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + for join_type in join_types { + let join = hash_join_exec(left.clone(), right.clone(), &join_on, &join_type); + // Join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = + hash_join_exec(join.clone(), parquet_exec(), &top_join_on, &join_type); + + let top_join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]", join_type); + let join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]", join_type); + + let expected = match join_type { + // Should include 3 RepartitionExecs + JoinType::Inner | JoinType::Left => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs + _ => vec![ + top_join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as 
b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join); + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + // This time we use (b1 == c) for top join + // Join on (b1 == c) + let top_join_on = vec![( + Column::new_with_schema("b1", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = + hash_join_exec(join, parquet_exec(), &top_join_on, &join_type); + let top_join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]", join_type); + + let expected = match join_type { + // Should include 3 RepartitionExecs + JoinType::Inner | JoinType::Right => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs + _ => vec![ + top_join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10)", + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 
1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join); + } + _ => {} + } + } + + Ok(()) + } + + #[test] + fn multi_joins_after_alias() -> Result<()> { + let left = parquet_exec(); + let right = parquet_exec(); + + // Join on (a == b) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b", &schema()).unwrap(), + )]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Projection(a as a1, a as a2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("a".to_string(), "a2".to_string()), + ]; + let projection = projection_exec_with_alias(join, alias_pairs); + + // Join on (a1 == c) + let top_join_on = vec![( + Column::new_with_schema("a1", &projection.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec( + projection.clone(), + right.clone(), + &top_join_on, + &JoinType::Inner, + ); + + // Output partition need to respect the Alias and should not introduce additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"c\", index: 2 })]", + "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)", + "ParquetExec: limit=None, partitions=[x], 
projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, top_join); + + // Join on (a2 == c) + let top_join_on = vec![( + Column::new_with_schema("a2", &projection.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner); + + // Output partition need to respect the Alias and should not introduce additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a2\", index: 1 }, Column { name: \"c\", index: 2 })]", + "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, top_join); + Ok(()) + } + + #[test] + fn multi_joins_after_multi_alias() -> Result<()> { + let left = parquet_exec(); + let right = parquet_exec(); + + // Join on (a == b) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b", &schema()).unwrap(), + )]; + + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Projection(c as c1) + let alias_pairs: Vec<(String, String)> = + vec![("c".to_string(), "c1".to_string())]; + let projection = projection_exec_with_alias(join, alias_pairs); + + // Projection(c1 
as a) + let alias_pairs: Vec<(String, String)> = + vec![("c1".to_string(), "a".to_string())]; + let projection2 = projection_exec_with_alias(projection, alias_pairs); + + // Join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &projection2.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec(projection2, right, &top_join_on, &JoinType::Inner); + + // The Column 'a' has different meaning now after the two Projections + // The original Output partition can not satisfy the Join requirements and need to add an additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"c\", index: 2 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ProjectionExec: expr=[c1@0 as a]", + "ProjectionExec: expr=[c@2 as c1]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, top_join); + Ok(()) + } + + #[test] + fn join_after_agg_alias() -> Result<()> { + // group by (a as a1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + // group by (a as a2) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a2".to_string())], + ); + + // Join on (a1 == a2) + let join_on = vec![( + Column::new_with_schema("a1", &left.schema()).unwrap(), 
+ Column::new_with_schema("a2", &right.schema()).unwrap(), + )]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"a2\", index: 0 })]", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }], 10)", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 0 }], 10)", + "AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join); + Ok(()) + } + + #[test] + fn hash_join_key_ordering() -> Result<()> { + // group by (a as a1, b as b1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ], + ); + // group by (b, a) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("b".to_string(), "b".to_string()), + ("a".to_string(), "a".to_string()), + ], + ); + + // Join on (b1 == b && a1 == a) + let join_on = vec![ + ( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a1", &left.schema()).unwrap(), + Column::new_with_schema("a", &right.schema()).unwrap(), + ), + ]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: 
\"a\", index: 1 })]", + "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)", + "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10)", + "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join); + Ok(()) + } + + #[test] + fn multi_hash_join_key_ordering() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + + // Join on (a == a1 and b == b1 and c == c1) + let join_on = vec![ + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ]; + let top_left_join = + hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner); + + // Projection(a as A, a as AA, b as B, c as C) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "A".to_string()), + ("a".to_string(), "AA".to_string()), + ("b".to_string(), "B".to_string()), + ("c".to_string(), "C".to_string()), + ]; + let projection = projection_exec_with_alias(top_left_join, 
alias_pairs); + + // Join on (c == c1 and b == b1 and a == a1) + let join_on = vec![ + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ]; + let top_right_join = + hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Join on (B == b1 and C == c and AA = a1) + let top_join_on = vec![ + ( + Column::new_with_schema("B", &projection.schema()).unwrap(), + Column::new_with_schema("b1", &top_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("C", &projection.schema()).unwrap(), + Column::new_with_schema("c", &top_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("AA", &projection.schema()).unwrap(), + Column::new_with_schema("a1", &top_right_join.schema()).unwrap(), + ), + ]; + + let top_join = hash_join_exec( + projection.clone(), + top_right_join, + &top_join_on, + &JoinType::Inner, + ); + + // Output partition need to respect the Alias and should not introduce additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"B\", index: 2 }, Column { name: \"b1\", index: 6 }), (Column { name: \"C\", index: 3 }, Column { name: \"c\", index: 2 }), (Column { name: \"AA\", index: 1 }, Column { name: \"a1\", index: 5 })]", + "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, 
Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, top_join); + Ok(()) + } + + #[test] + fn multi_smj_joins() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ("d".to_string(), "d1".to_string()), + ("e".to_string(), "e1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::Semi, + JoinType::Anti, + ]; + + // Join on (a == b1) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + for join_type in join_types { + let join = + 
sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); + + // Top join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = sort_merge_join_exec( + join.clone(), + parquet_exec(), + &top_join_on, + &join_type, + ); + + let top_join_plan = + format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]", join_type); + let join_plan = + format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]", join_type); + + let expected = match join_type { + // Should include 3 RepartitionExecs 3 SortExecs + JoinType::Inner | JoinType::Left => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortExec: [a@0 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [b1@1 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [c@2 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs + _ => vec![ + top_join_plan.as_str(), + "SortExec: [a@0 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + join_plan.as_str(), + "SortExec: [a@0 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [b1@1 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, 
c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [c@2 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join); + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + // This time we use (b1 == c) for top join + // Join on (b1 == c) + let top_join_on = vec![( + Column::new_with_schema("b1", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + let top_join = sort_merge_join_exec( + join, + parquet_exec(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]", join_type); + + let expected = match join_type { + // Should include 3 RepartitionExecs and 3 SortExecs + JoinType::Inner | JoinType::Right => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortExec: [a@0 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [b1@1 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [c@2 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs and 4 SortExecs + _ => vec![ + top_join_plan.as_str(), + "SortExec: [b1@6 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10)", + join_plan.as_str(), + "SortExec: [a@0 ASC]", + "RepartitionExec: 
partitioning=Hash([Column { name: \"a\", index: 0 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [b1@1 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [c@2 ASC]", + "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join); + } + _ => {} + } + } + + Ok(()) + } + + #[test] + fn smj_join_key_ordering() -> Result<()> { + // group by (a as a1, b as b1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ], + ); + //Projection(a1 as a3, b1 as b3) + let alias_pairs: Vec<(String, String)> = vec![ + ("a1".to_string(), "a3".to_string()), + ("b1".to_string(), "b3".to_string()), + ]; + let left = projection_exec_with_alias(left, alias_pairs); + + // group by (b, a) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("b".to_string(), "b".to_string()), + ("a".to_string(), "a".to_string()), + ], + ); + + //Projection(a as a2, b as b2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a2".to_string()), + ("b".to_string(), "b2".to_string()), + ]; + let right = projection_exec_with_alias(right, alias_pairs); + + // Join on (b3 == b2 && a3 == a2) + let join_on = vec![ + ( + Column::new_with_schema("b3", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a3", &left.schema()).unwrap(), + Column::new_with_schema("a2", &right.schema()).unwrap(), + ), + ]; + let join = sort_merge_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ 
+ "SortMergeJoin: join_type=Inner, on=[(Column { name: \"b3\", index: 1 }, Column { name: \"b2\", index: 1 }), (Column { name: \"a3\", index: 0 }, Column { name: \"a2\", index: 0 })]", + "SortExec: [b3@1 ASC,a3@0 ASC]", + "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", + "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)", + "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + "SortExec: [b2@1 ASC,a2@0 ASC]", + "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", + "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10)", + "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", + "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/merge_exec.rs b/datafusion/core/src/physical_optimizer/merge_exec.rs index f614673f500e6..77fcce9d3b601 100644 --- a/datafusion/core/src/physical_optimizer/merge_exec.rs +++ b/datafusion/core/src/physical_optimizer/merge_exec.rs @@ -52,27 +52,21 @@ impl PhysicalOptimizerRule for AddCoalescePartitionsExec { .iter() .map(|child| self.optimize(child.clone(), config)) .collect::>>()?; - match plan.required_child_distribution() { - Distribution::UnspecifiedDistribution => { - with_new_children_if_necessary(plan, children) - } - Distribution::HashPartitioned(_) => { - with_new_children_if_necessary(plan, children) - } - Distribution::SinglePartition => with_new_children_if_necessary( - plan, - children - .iter() - .map(|child| { - if child.output_partitioning().partition_count() == 1 { - child.clone() - } 
else { - Arc::new(CoalescePartitionsExec::new(child.clone())) - } - }) - .collect(), - ), - } + assert_eq!(children.len(), plan.required_input_distribution().len()); + + let new_children = children + .into_iter() + .zip(plan.required_input_distribution()) + .map(|(child, dist)| match dist { + Distribution::SinglePartition + if child.output_partitioning().partition_count() > 1 => + { + Arc::new(CoalescePartitionsExec::new(child.clone())) + } + _ => child, + }) + .collect::>(); + with_new_children_if_necessary(plan, new_children) } } diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 55550bcd2cffc..5ecb9cd37a48d 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -20,6 +20,7 @@ pub mod aggregate_statistics; pub mod coalesce_batches; +pub mod enforcement; pub mod hash_build_probe_order; pub mod merge_exec; pub mod optimizer; diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs index 839908d0659bb..c0aa9bf7a96db 100644 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ b/datafusion/core/src/physical_optimizer/repartition.rs @@ -137,7 +137,7 @@ impl Repartition { /// /// 1. Has fewer partitions than `target_partitions` /// -/// 2. Has a direct parent that `benefits_from_input_partitioning` +/// 2. Has a direct parent that `prefer_parallel` /// /// 3. 
Does not have a parent that `relies_on_input_order` unless there /// is an intervening node that does not `maintain_input_order` @@ -189,7 +189,7 @@ fn optimize_partitions( target_partitions, child.clone(), can_reorder_children, - plan.benefits_from_input_partitioning(), + plan.prefer_parallel(), ) }) .collect::>()?; @@ -234,6 +234,7 @@ impl PhysicalOptimizerRule for Repartition { "repartition" } } + #[cfg(test)] mod tests { use arrow::compute::SortOptions; diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 2c4a9b26c3994..90305d7556adc 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -34,9 +34,12 @@ use datafusion_common::Result; use datafusion_expr::Accumulator; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr, + expressions, merge_equivalence_properties_with_alias, + normalize_out_expr_with_alias_schema, truncate_equivalence_properties_not_in_schema, + AggregateExpr, PhysicalExpr, PhysicalSortExpr, }; use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; @@ -150,19 +153,21 @@ impl PhysicalGroupBy { #[derive(Debug)] pub struct AggregateExec { /// Aggregation mode (full, partial) - mode: AggregateMode, + pub mode: AggregateMode, /// Group by expressions - group_by: PhysicalGroupBy, + pub group_by: PhysicalGroupBy, /// Aggregate expressions - aggr_expr: Vec>, + pub aggr_expr: Vec>, /// Input plan, could be a partial aggregate or the input to the aggregate - input: Arc, + pub input: Arc, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. 
For partial aggregate this will be the /// same as input.schema() but for the final aggregate it will be the same as the input /// to the partial aggregate - input_schema: SchemaRef, + pub input_schema: SchemaRef, + /// The alias map used to normalize out expressions like Partitioning + alias_map: HashMap>, /// Execution Metrics metrics: ExecutionPlanMetricsSet, } @@ -186,6 +191,18 @@ impl AggregateExec { let schema = Arc::new(schema); + let mut alias_map: HashMap> = HashMap::new(); + for (expression, name) in group_by.expr.iter() { + if let Some(column) = expression.as_any().downcast_ref::() { + let new_col_idx = schema.index_of(name)?; + // When the column name is the same, but index does not equal, treat it as Alias + if (column.name() != name) || (column.index() != new_col_idx) { + let entry = alias_map.entry(column.clone()).or_insert_with(Vec::new); + entry.push(Column::new(name, new_col_idx)); + } + }; + } + Ok(AggregateExec { mode, group_by, @@ -193,6 +210,7 @@ impl AggregateExec { input, schema, input_schema, + alias_map, metrics: ExecutionPlanMetricsSet::new(), }) } @@ -255,25 +273,57 @@ impl ExecutionPlan for AggregateExec { /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() + match &self.mode { + AggregateMode::Partial => { + // Partial Aggregation will not change the output partitioning but need to respect the Alias + let input_partition = self.input.output_partitioning(); + match input_partition { + Partitioning::Hash(exprs, part) => { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + normalize_out_expr_with_alias_schema( + expr, + &self.alias_map, + &self.schema, + ) + }) + .collect::>(); + Partitioning::Hash(normalized_exprs, part) + } + _ => input_partition, + } + } + // Final Aggregation's output partitioning is the same as its real input + _ => self.input.output_partitioning(), + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { None } - fn 
required_child_distribution(&self) -> Distribution { + fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => Distribution::UnspecifiedDistribution, - AggregateMode::FinalPartitioned => Distribution::HashPartitioned( - self.group_by.expr.iter().map(|x| x.0.clone()).collect(), - ), - AggregateMode::Final => Distribution::SinglePartition, + AggregateMode::Partial => vec![Distribution::UnspecifiedDistribution], + AggregateMode::FinalPartitioned => { + vec![Distribution::HashPartitioned(self.output_group_expr())] + } + AggregateMode::Final => vec![Distribution::SinglePartition], } } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + let mut input_equivalence_properties = self.input.equivalence_properties(); + merge_equivalence_properties_with_alias( + &mut input_equivalence_properties, + &self.alias_map, + ); + truncate_equivalence_properties_not_in_schema( + &mut input_equivalence_properties, + &self.schema, + ); + input_equivalence_properties } fn children(&self) -> Vec> { @@ -654,7 +704,7 @@ mod tests { use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; - use datafusion_physical_expr::expressions::{lit, Count}; + use datafusion_physical_expr::expressions::{lit, Column, Count}; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use futures::{FutureExt, Stream}; use std::any::Any; @@ -924,6 +974,10 @@ mod tests { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs index 8134ee7d2f2da..eeb65cc21fe10 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/core/src/physical_plan/analyze.rs @@ -28,6 +28,7 @@ use crate::{ }, }; use arrow::{array::StringBuilder, datatypes::SchemaRef, 
record_batch::RecordBatch}; +use datafusion_physical_expr::expressions::Column; use futures::StreamExt; use super::expressions::PhysicalSortExpr; @@ -72,8 +73,8 @@ impl ExecutionPlan for AnalyzeExec { } /// Specifies we want the input as a single stream - fn required_child_distribution(&self) -> Distribution { - Distribution::SinglePartition + fn required_input_distribution(&self) -> Vec { + vec![Distribution::SinglePartition] } /// Get the output partitioning of this plan @@ -85,8 +86,8 @@ impl ExecutionPlan for AnalyzeExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/core/src/physical_plan/coalesce_batches.rs index 317500ddc904f..b8d690f2bbf1a 100644 --- a/datafusion/core/src/physical_plan/coalesce_batches.rs +++ b/datafusion/core/src/physical_plan/coalesce_batches.rs @@ -34,6 +34,7 @@ use arrow::compute::kernels::concat::concat; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_physical_expr::expressions::Column; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -96,12 +97,15 @@ impl ExecutionPlan for CoalesceBatchesExec { self.input.output_partitioning() } + // Depends on how the CoalesceBatches was implemented, it is possible to keep + // the input ordering when combines small batches into larger batches + // TODO revisit the logic later fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index d1c797eacd5c9..e6378e2d33571 100644 --- 
a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -27,6 +27,7 @@ use tokio::sync::mpsc; use arrow::record_batch::RecordBatch; use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; +use datafusion_physical_expr::expressions::Column; use super::common::AbortOnDropMany; use super::expressions::PhysicalSortExpr; @@ -87,8 +88,8 @@ impl ExecutionPlan for CoalescePartitionsExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/empty.rs b/datafusion/core/src/physical_plan/empty.rs index c693764c87aa0..a8e17d7d854dc 100644 --- a/datafusion/core/src/physical_plan/empty.rs +++ b/datafusion/core/src/physical_plan/empty.rs @@ -22,11 +22,12 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - memory::MemoryStream, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning, }; use arrow::array::NullArray; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_physical_expr::expressions::Column; use log::debug; use super::expressions::PhysicalSortExpr; @@ -98,10 +99,6 @@ impl ExecutionPlan for EmptyExec { vec![] } - fn required_child_distribution(&self) -> Distribution { - Distribution::UnspecifiedDistribution - } - /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(self.partitions) @@ -111,6 +108,10 @@ impl ExecutionPlan for EmptyExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn with_new_children( self: Arc, _: Vec>, diff --git a/datafusion/core/src/physical_plan/explain.rs b/datafusion/core/src/physical_plan/explain.rs index 15f459fb045b1..463dbac30ad29 
100644 --- a/datafusion/core/src/physical_plan/explain.rs +++ b/datafusion/core/src/physical_plan/explain.rs @@ -29,6 +29,7 @@ use crate::{ }, }; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_physical_expr::expressions::Column; use log::debug; use super::{expressions::PhysicalSortExpr, SendableRecordBatchStream}; @@ -97,8 +98,8 @@ impl ExecutionPlan for ExplainExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/file_format/avro.rs b/datafusion/core/src/physical_plan/file_format/avro.rs index 2aab84fadbcfc..24c146b865885 100644 --- a/datafusion/core/src/physical_plan/file_format/avro.rs +++ b/datafusion/core/src/physical_plan/file_format/avro.rs @@ -25,6 +25,7 @@ use arrow::datatypes::SchemaRef; use crate::execution::context::TaskContext; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; +use datafusion_physical_expr::expressions::Column; use std::any::Any; use std::sync::Arc; @@ -76,8 +77,8 @@ impl ExecutionPlan for AvroExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn children(&self) -> Vec> { diff --git a/datafusion/core/src/physical_plan/file_format/csv.rs b/datafusion/core/src/physical_plan/file_format/csv.rs index d086a77982357..122dc59726701 100644 --- a/datafusion/core/src/physical_plan/file_format/csv.rs +++ b/datafusion/core/src/physical_plan/file_format/csv.rs @@ -34,7 +34,7 @@ use arrow::csv; use arrow::datatypes::SchemaRef; use bytes::Buf; - +use datafusion_physical_expr::expressions::Column; use futures::{StreamExt, TryStreamExt}; use object_store::{GetResult, ObjectStore}; use std::any::Any; @@ -109,14 +109,14 @@ impl ExecutionPlan for CsvExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn relies_on_input_order(&self) -> bool { - false - } - fn 
output_ordering(&self) -> Option<&[PhysicalSortExpr]> { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { // this is a leaf node and has no children vec![] diff --git a/datafusion/core/src/physical_plan/file_format/json.rs b/datafusion/core/src/physical_plan/file_format/json.rs index c8c5d71bd73f2..e8e98d975cb80 100644 --- a/datafusion/core/src/physical_plan/file_format/json.rs +++ b/datafusion/core/src/physical_plan/file_format/json.rs @@ -34,6 +34,7 @@ use arrow::json::reader::DecoderOptions; use arrow::{datatypes::SchemaRef, json}; use bytes::Buf; +use datafusion_physical_expr::expressions::Column; use futures::{StreamExt, TryStreamExt}; use object_store::{GetResult, ObjectStore}; @@ -91,8 +92,8 @@ impl ExecutionPlan for NdJsonExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn children(&self) -> Vec> { diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index f5bd890591fd5..492c0959002cb 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -290,8 +290,10 @@ impl ExecutionPlan for ParquetExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties( + &self, + ) -> Vec> { + vec![] } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs index b4e3edaee05fd..bc7c6fd8e9238 100644 --- a/datafusion/core/src/physical_plan/filter.rs +++ b/datafusion/core/src/physical_plan/filter.rs @@ -39,6 +39,11 @@ use arrow::record_batch::RecordBatch; use log::debug; use crate::execution::context::TaskContext; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, Column}; +use datafusion_physical_expr::{ + combine_equivalence_properties, remove_equivalence_properties, 
split_predicate, +}; use futures::stream::{Stream, StreamExt}; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to @@ -113,8 +118,17 @@ impl ExecutionPlan for FilterExec { true } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + // Combine the equal predicates with the input equivalence properties + let mut input_properties = self.input.equivalence_properties(); + let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&self.predicate); + for new_condition in equal_pairs { + combine_equivalence_properties(&mut input_properties, new_condition) + } + for remove_condition in ne_pairs { + remove_equivalence_properties(&mut input_properties, remove_condition) + } + input_properties } fn with_new_children( @@ -231,6 +245,39 @@ impl RecordBatchStream for FilterExecStream { } } +/// Return the equals Column-Pairs and Non-equals Column-Pairs +fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { + let mut eq_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); + let mut ne_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); + + let predicates = split_predicate(predicate); + predicates.into_iter().for_each(|p| { + if let Some(binary) = p.as_any().downcast_ref::() { + let left = binary.left(); + let right = binary.right(); + if left.as_any().is::() && right.as_any().is::() { + let left_column = left.as_any().downcast_ref::().unwrap(); + let right_column = right.as_any().downcast_ref::().unwrap(); + match binary.op() { + Operator::Eq => { + eq_predicate_columns.push((left_column, right_column)) + } + Operator::NotEq => { + ne_predicate_columns.push((left_column, right_column)) + } + _ => {} + } + } + } + }); + + (eq_predicate_columns, ne_predicate_columns) +} + +/// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates +pub type EqualAndNonEqual<'a> = + (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>); + #[cfg(test)] mod
tests { diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs b/datafusion/core/src/physical_plan/joins/cross_join.rs index 7a35116a46585..1cd3b9097de5e 100644 --- a/datafusion/core/src/physical_plan/joins/cross_join.rs +++ b/datafusion/core/src/physical_plan/joins/cross_join.rs @@ -34,6 +34,7 @@ use crate::physical_plan::{ }; use crate::{error::Result, scalar::ScalarValue}; use async_trait::async_trait; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; use log::debug; use std::time::Instant; @@ -161,8 +162,11 @@ impl ExecutionPlan for CrossJoinExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + let mut left_properties = self.left.equivalence_properties(); + let right_properties = self.left.equivalence_properties(); + left_properties.extend(right_properties); + left_properties } fn execute( diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index ff036b78b32e6..c5248a85241f6 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -66,8 +66,8 @@ use crate::physical_plan::{ JoinFilter, JoinOn, JoinSide, }, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::error::{DataFusionError, Result}; @@ -77,6 +77,9 @@ use crate::arrow::array::BooleanBufferBuilder; use crate::arrow::datatypes::TimeUnit; use crate::execution::context::TaskContext; +use datafusion_physical_expr::combine_equivalence_properties; +use datafusion_physical_expr::TreeNodeRewritable; + use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, @@ -117,15 
+120,15 @@ type JoinLeftData = (JoinHashMap, RecordBatch); #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed - left: Arc, + pub left: Arc, /// right (probe) side which are filtered by the hash table - right: Arc, + pub right: Arc, /// Set of common columns used to join on - on: Vec<(Column, Column)>, + pub on: Vec<(Column, Column)>, /// Filters which are applied while finding matching rows - filter: Option, + pub filter: Option, /// How the join is performed - join_type: JoinType, + pub join_type: JoinType, /// The schema once the join is applied schema: SchemaRef, /// Build-side data @@ -133,13 +136,13 @@ pub struct HashJoinExec { /// Shares the `RandomState` for the hashing algorithm random_state: RandomState, /// Partitioning mode to use - mode: PartitionMode, + pub mode: PartitionMode, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec, /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + pub null_equals_null: bool, } /// Metrics for HashJoinExec @@ -274,6 +277,31 @@ impl ExecutionPlan for HashJoinExec { vec![self.left.clone(), self.right.clone()] } + fn required_input_distribution(&self) -> Vec { + match self.mode { + PartitionMode::CollectLeft => vec![ + Distribution::SinglePartition, + Distribution::UnspecifiedDistribution, + ], + PartitionMode::Partitioned => { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + } + } + fn with_new_children( self: Arc, children: Vec>, @@ -290,15 +318,84 @@ impl ExecutionPlan for HashJoinExec { } fn output_partitioning(&self) -> Partitioning { - self.right.output_partitioning() + match self.join_type { + JoinType::Inner => self.left.output_partitioning(), + 
JoinType::Left => self.left.output_partitioning(), + JoinType::Right => { + let left_columns_len = self.left.schema().fields.len(); + match self.right.output_partitioning() { + Partitioning::RoundRobinBatch(size) => { + Partitioning::RoundRobinBatch(size) + } + Partitioning::Hash(exprs, size) => { + let new_exprs = exprs + .into_iter() + .map(|expr| { + expr.transform_down(&|e| match e + .as_any() + .downcast_ref::() + { + Some(col) => Some(Arc::new(Column::new( + col.name(), + left_columns_len + col.index(), + ))), + None => None, + }) + .unwrap() + }) + .collect::>(); + Partitioning::Hash(new_exprs, size) + } + Partitioning::UnknownPartitioning(size) => { + Partitioning::UnknownPartitioning(size) + } + } + } + _ => Partitioning::UnknownPartitioning( + self.right.output_partitioning().partition_count(), + ), + } } + // Output ordering might be kept for some cases. + // For example if it is inner join then the stream side order can be kept fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + let mut left_properties = self.left.equivalence_properties(); + match self.join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let right_properties = self.right.equivalence_properties(); + let left_columns_len = self.left.schema().fields.len(); + let new_right_properties = right_properties + .into_iter() + .map(|cols| { + cols.into_iter() + .map(|col| { + Column::new(col.name(), left_columns_len + col.index()) + }) + .collect::>() + }) + .collect::>(); + left_properties.extend(new_right_properties); + } + JoinType::Semi | JoinType::Anti => {} + } + + if self.join_type == JoinType::Inner { + let left_columns_len = self.left.schema().fields.len(); + self.on.iter().for_each(|(column1, column2)| { + let new_column2 = + Column::new(column2.name(), left_columns_len + column2.index()); + combine_equivalence_properties( + &mut 
left_properties, + (column1, &new_column2), + ) + }) + } + left_properties } fn execute( @@ -309,6 +406,7 @@ impl ExecutionPlan for HashJoinExec { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + //TODO fix this in distribution model, need to add shuffle let left_fut = match self.mode { PartitionMode::CollectLeft => self.left_fut.once(|| { collect_left_input( diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 3de712745de4f..d7659fa528320 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -33,6 +33,9 @@ use arrow::compute::{take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; +use datafusion_physical_expr::{ + combine_equivalence_properties, PhysicalExpr, TreeNodeRewritable, +}; use futures::{Stream, StreamExt}; use crate::error::DataFusionError; @@ -47,8 +50,8 @@ use crate::physical_plan::joins::utils::{ }; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::physical_plan::{ - metrics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + metrics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; /// join execution plan executes partitions in parallel and combines them into a set of @@ -56,21 +59,27 @@ use crate::physical_plan::{ #[derive(Debug)] pub struct SortMergeJoinExec { /// Left sorted joining execution plan - left: Arc, + pub left: Arc, /// Right sorting joining execution plan - right: Arc, + pub right: Arc, /// Set of common columns used to join on - on: JoinOn, + pub on: JoinOn, /// How the join is performed 
- join_type: JoinType, + pub join_type: JoinType, /// The schema once the join is applied schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// The left SortExpr + left_sort_exprs: Vec, + /// The right SortExpr + right_sort_exprs: Vec, + /// The output ordering + output_ordering: Option>, /// Sort options of join columns used in sorting left and right execution plans - sort_options: Vec, + pub sort_options: Vec, /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + pub null_equals_null: bool, } impl SortMergeJoinExec { @@ -98,6 +107,69 @@ impl SortMergeJoinExec { ))); } + let (left_expr, right_expr): (Vec<_>, Vec<_>) = on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let left_sort_exprs = left_expr + .into_iter() + .zip(sort_options.iter()) + .map(|(k, sort_op)| PhysicalSortExpr { + expr: k, + options: *sort_op, + }) + .collect::>(); + + let right_sort_exprs = right_expr + .into_iter() + .zip(sort_options.iter()) + .map(|(k, sort_op)| PhysicalSortExpr { + expr: k, + options: *sort_op, + }) + .collect::>(); + + let output_ordering = match join_type { + JoinType::Inner | JoinType::Left | JoinType::Semi | JoinType::Anti => { + left.output_ordering().map(|sort_exprs| sort_exprs.to_vec()) + } + JoinType::Right => { + let left_columns_len = left.schema().fields.len(); + right.output_ordering().map(|sort_exprs| { + sort_exprs + .iter() + .map(|e| { + let new_expr = e + .expr + .clone() + .transform_down(&|e| match e + .as_any() + .downcast_ref::() + { + Some(col) => Some(Arc::new(Column::new( + col.name(), + left_columns_len + col.index(), + ))), + None => None, + }) + .unwrap(); + PhysicalSortExpr { + expr: new_expr, + options: e.options, + } + }) + .collect::>() + }) + } + JoinType::Full => None, + }; + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); @@ -108,6 +180,9 @@ impl SortMergeJoinExec { 
join_type, schema, metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + output_ordering, sort_options, null_equals_null, }) @@ -124,21 +199,103 @@ impl ExecutionPlan for SortMergeJoinExec { } fn output_partitioning(&self) -> Partitioning { - self.right.output_partitioning() + match self.join_type { + JoinType::Inner => self.left.output_partitioning(), + JoinType::Left => self.left.output_partitioning(), + JoinType::Right => { + let left_columns_len = self.left.schema().fields.len(); + match self.right.output_partitioning() { + Partitioning::RoundRobinBatch(size) => { + Partitioning::RoundRobinBatch(size) + } + Partitioning::Hash(exprs, size) => { + let new_exprs = exprs + .into_iter() + .map(|expr| { + expr.transform_down(&|e| match e + .as_any() + .downcast_ref::() + { + Some(col) => Some(Arc::new(Column::new( + col.name(), + left_columns_len + col.index(), + ))), + None => None, + }) + .unwrap() + }) + .collect::>(); + Partitioning::Hash(new_exprs, size) + } + Partitioning::UnknownPartitioning(size) => { + Partitioning::UnknownPartitioning(size) + } + } + } + _ => Partitioning::UnknownPartitioning( + self.right.output_partitioning().partition_count(), + ), + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + fn required_input_ordering(&self) -> Vec> { + vec![Some(&self.left_sort_exprs), Some(&self.right_sort_exprs)] + } + + fn equivalence_properties(&self) -> Vec> { + let mut left_properties = self.left.equivalence_properties(); match self.join_type { - JoinType::Inner | JoinType::Left | JoinType::Semi | JoinType::Anti => { - self.left.output_ordering() + JoinType::Inner | 
JoinType::Left | JoinType::Full | JoinType::Right => { + let right_properties = self.right.equivalence_properties(); + let left_columns_len = self.left.schema().fields.len(); + let new_right_properties = right_properties + .into_iter() + .map(|cols| { + cols.into_iter() + .map(|col| { + Column::new(col.name(), left_columns_len + col.index()) + }) + .collect::>() + }) + .collect::>(); + left_properties.extend(new_right_properties); } - JoinType::Right => self.right.output_ordering(), - JoinType::Full => None, + JoinType::Semi | JoinType::Anti => {} } - } - fn relies_on_input_order(&self) -> bool { - true + if self.join_type == JoinType::Inner { + let left_columns_len = self.left.schema().fields.len(); + self.on.iter().for_each(|(column1, column2)| { + let new_column2 = + Column::new(column2.name(), left_columns_len + column2.index()); + combine_equivalence_properties( + &mut left_properties, + (column1, &new_column2), + ) + }) + } + left_properties } fn children(&self) -> Vec> { @@ -219,8 +376,8 @@ impl ExecutionPlan for SortMergeJoinExec { DisplayFormatType::Default => { write!( f, - "SortMergeJoin: join_type={:?}, on={:?}, schema={:?}", - self.join_type, self.on, &self.schema + "SortMergeJoin: join_type={:?}, on={:?}", + self.join_type, self.on ) } } diff --git a/datafusion/core/src/physical_plan/limit.rs b/datafusion/core/src/physical_plan/limit.rs index 322c21ff419cc..cb3c389b3b2ae 100644 --- a/datafusion/core/src/physical_plan/limit.rs +++ b/datafusion/core/src/physical_plan/limit.rs @@ -34,6 +34,7 @@ use arrow::compute::limit; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +use datafusion_physical_expr::expressions::Column; use super::expressions::PhysicalSortExpr; use super::{ @@ -98,8 +99,8 @@ impl ExecutionPlan for GlobalLimitExec { vec![self.input.clone()] } - fn required_child_distribution(&self) -> Distribution { - Distribution::SinglePartition + fn required_input_distribution(&self) -> Vec 
{ + vec![Distribution::SinglePartition] } /// Get the output partitioning of this plan @@ -115,7 +116,7 @@ impl ExecutionPlan for GlobalLimitExec { true } - fn benefits_from_input_partitioning(&self) -> bool { + fn prefer_parallel(&self) -> bool { false } @@ -123,6 +124,10 @@ impl ExecutionPlan for GlobalLimitExec { self.input.output_ordering() } + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() + } + fn with_new_children( self: Arc, children: Vec>, @@ -277,18 +282,17 @@ impl ExecutionPlan for LocalLimitExec { self.input.output_ordering().is_some() } - fn benefits_from_input_partitioning(&self) -> bool { + fn prefer_parallel(&self) -> bool { false } - // Local limit does not make any attempt to maintain the input - // sortedness (if there is more than one partition) + // Local limit will not change the input plan's ordering fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - if self.output_partitioning().partition_count() == 1 { - self.input.output_ordering() - } else { - None - } + self.input.output_ordering() + } + + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index d2dbe0d738c39..3b33fa4fed5a5 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch; use crate::execution::context::TaskContext; use datafusion_common::DataFusionError; +use datafusion_physical_expr::expressions::Column; use futures::Stream; /// Execution plan for reading in-memory batches of data @@ -81,8 +82,8 @@ impl ExecutionPlan for MemoryExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index 
9e36c3ec80edc..0fa01ccba8ad0 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -42,6 +42,7 @@ use datafusion_common::DataFusionError; use std::sync::Arc; use std::task::{Context, Poll}; use std::{any::Any, pin::Pin}; +use tokio::macros::support::thread_rng_n; /// Trait for types that stream [arrow::record_batch::RecordBatch] pub trait RecordBatchStream: Stream> { @@ -86,6 +87,73 @@ impl Stream for EmptyRecordBatchStream { } } +/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one +pub struct CombinedRecordBatchStream { + /// Schema wrapped by Arc + schema: SchemaRef, + /// Stream entries + entries: Vec, +} + +impl CombinedRecordBatchStream { + /// Create an CombinedRecordBatchStream + pub fn new(schema: SchemaRef, entries: Vec) -> Self { + Self { schema, entries } + } +} + +impl RecordBatchStream for CombinedRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for CombinedRecordBatchStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + use Poll::*; + + let start = thread_rng_n(self.entries.len() as u32) as usize; + let mut idx = start; + + for _ in 0..self.entries.len() { + let stream = self.entries.get_mut(idx).unwrap(); + + match Pin::new(stream).poll_next(cx) { + Ready(Some(val)) => return Ready(Some(val)), + Ready(None) => { + // Remove the entry + self.entries.swap_remove(idx); + + // Check if this was the last entry, if so the cursor needs + // to wrap + if idx == self.entries.len() { + idx = 0; + } else if idx < start && start <= self.entries.len() { + // The stream being swapped into the current index has + // already been polled, so skip it. + idx = idx.wrapping_add(1) % self.entries.len(); + } + } + Pending => { + idx = idx.wrapping_add(1) % self.entries.len(); + } + } + } + + // If the map is empty, then the stream is complete. 
+ if self.entries.is_empty() { + Ready(None) + } else { + Pending + } + } +} + /// Physical planner interface pub use self::planner::PhysicalPlanner; @@ -110,6 +178,8 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// Specifies the output partitioning scheme of this plan fn output_partitioning(&self) -> Partitioning; + /// Describe how data is ordered in each partition. + /// /// If the output of this operator is sorted, returns `Some(keys)` /// with the description of how it was sorted. /// @@ -122,10 +192,20 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// have any particular output order here fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; - /// Specifies the data distribution requirements of all the - /// children for this operator - fn required_child_distribution(&self) -> Distribution { - Distribution::UnspecifiedDistribution + /// Specifies the data distribution requirements for all the + /// children for this operator, By default it's [[Distribution::UnspecifiedDistribution]] for each child, + fn required_input_distribution(&self) -> Vec { + if !self.children().is_empty() { + vec![Distribution::UnspecifiedDistribution; self.children().len()] + } else { + vec![Distribution::UnspecifiedDistribution] + } + } + + /// Specifies the ordering requirements for all the + /// children for this operator. + fn required_input_ordering(&self) -> Vec> { + vec![None; self.children().len()] } /// Returns `true` if this operator relies on its inputs being @@ -136,13 +216,17 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// optimizations which might reorder the inputs (such as /// repartitioning to increase concurrency). /// - /// The default implementation returns `true` + /// The default implementation checks the input ordering requirements + /// and if there is non empty ordering requirements to the input, the method will + /// return `true`. 
/// /// WARNING: if you override this default and return `false`, your /// operator can not rely on DataFusion preserving the input order /// as it will likely not. fn relies_on_input_order(&self) -> bool { - true + self.required_input_ordering() + .iter() + .any(|ordering| matches!(ordering, Some(_))) } /// Returns `false` if this operator's implementation may reorder @@ -172,15 +256,18 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// /// The default implementation returns `true` unless this operator /// has signalled it requires a single child input partition. - fn benefits_from_input_partitioning(&self) -> bool { + fn prefer_parallel(&self) -> bool { // By default try to maximize parallelism with more CPUs if - // possible - !matches!( - self.required_child_distribution(), - Distribution::SinglePartition - ) + // possibles + !self + .required_input_distribution() + .into_iter() + .any(|dist| matches!(dist, Distribution::SinglePartition)) } + /// Get a list of equivalence properties within the plan + fn equivalence_properties(&self) -> Vec>; + /// Get a list of child execution plans that provide the input for this plan. The returned list /// will be empty for leaf nodes, will contain a single value for unary nodes, or two /// values for binary nodes (such as joins). @@ -385,6 +472,164 @@ pub fn visit_execution_plan( Ok(()) } +/// a Trait for marking tree node types that are rewritable +pub trait TreeNodeRewritable: Clone { + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. 
+ /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutatate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that node are visited, nor is mutate + /// called on that node + /// + fn transform_using>( + self, + rewriter: &mut R, + ) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + let after_op_children = + self.map_children_mut(|node| node.transform_using(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. + /// When `op` does not apply to a given node, it is left uncshanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + self.transform_up(op) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. 
+ fn transform_down(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + let node_cloned = self.clone(); + let after_op = match op(node_cloned) { + Some(value) => value, + None => self, + }; + after_op.map_children(|node| node.transform_down(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + let after_op_children = self.map_children(|node| node.transform_up(op))?; + + let after_op_children_clone = after_op_children.clone(); + let new_node = match op(after_op_children) { + Some(value) => value, + None => after_op_children_clone, + }; + Ok(new_node) + } + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children(self, transform: F) -> Result + where + F: Fn(Self) -> Result; + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children_mut(self, transform: F) -> Result + where + F: FnMut(Self) -> Result; +} + +/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node +/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is +/// invoked recursively on all nodes of a tree. +pub trait TreeNodeRewriter: Sized { + /// Invoked before (Preorder) any children of `node` are rewritten / + /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _node: &N) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after (Postorder) all children of `node` have been mutated and + /// returns a potentially modified node. + fn mutate(&mut self, node: N) -> Result; +} + +/// Controls how the [TreeNodeRewriter] recursion should proceed. 
+pub enum RewriteRecursion { + /// Continue rewrite / visit this node tree. + Continue, + /// Call 'op' immediately and return. + Mutate, + /// Do not rewrite / visit the children of this node. + Stop, + /// Keep recursive but skip apply op on this node + Skip, +} + +impl TreeNodeRewritable for Arc { + fn map_children(self, transform: F) -> Result + where + F: Fn(Self) -> Result, + { + if !self.children().is_empty() { + let new_children: Result> = + self.children().into_iter().map(transform).collect(); + with_new_children_if_necessary(self, new_children?) + } else { + Ok(self) + } + } + + fn map_children_mut(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + if !self.children().is_empty() { + let new_children: Result> = + self.children().into_iter().map(transform).collect(); + with_new_children_if_necessary(self, new_children?) + } else { + Ok(self) + } + } +} + /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect( plan: Arc, @@ -458,6 +703,82 @@ impl Partitioning { RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, } } + + /// Returns true when the guarantees made by this [[Partitioning]] are sufficient to + /// satisfy the partitioning scheme mandated by the `required` [[Distribution]] + pub fn satisfy Vec>>( + &self, + required: Distribution, + equal_properties: F, + ) -> bool { + match required { + Distribution::UnspecifiedDistribution => true, + Distribution::SinglePartition if self.partition_count() == 1 => true, + Distribution::HashPartitioned(required_exprs) => { + match self { + // Here we do not check the partition count for hash partitioning and assumes the partition count + // and hash functions in the system are the same. In future if we plan to support storage partition-wise joins, + // then we need to have the partition count and hash functions validation. 
+ Partitioning::Hash(partition_exprs, _) => { + let fast_match = + expr_list_eq_strict_order(&required_exprs, partition_exprs); + // If the required exprs do not match, need to leverage the eq_properties provided by the child + // and normalize both exprs based on the eq_properties + if !fast_match { + let eq_properties = equal_properties(); + if !eq_properties.is_empty() { + let normalized_required_exprs = required_exprs + .iter() + .map(|e| { + normalize_expr_with_equivalence_properties( + e.clone(), + &eq_properties, + ) + }) + .collect::>(); + let normalized_partition_exprs = partition_exprs + .iter() + .map(|e| { + normalize_expr_with_equivalence_properties( + e.clone(), + &eq_properties, + ) + }) + .collect::>(); + expr_list_eq_strict_order( + &normalized_required_exprs, + &normalized_partition_exprs, + ) + } else { + fast_match + } + } else { + fast_match + } + } + _ => false, + } + } + _ => false, + } + } +} + +impl PartialEq for Partitioning { + fn eq(&self, other: &Partitioning) -> bool { + match (self, other) { + ( + Partitioning::RoundRobinBatch(count1), + Partitioning::RoundRobinBatch(count2), + ) if count1 == count2 => true, + (Partitioning::Hash(exprs1, count1), Partitioning::Hash(exprs2, count2)) + if expr_list_eq_strict_order(exprs1, exprs2) && (count1 == count2) => + { + true + } + _ => false, + } + } } /// Distribution schemes @@ -472,6 +793,21 @@ pub enum Distribution { HashPartitioned(Vec>), } +impl Distribution { + /// Creates a Partitioning for this Distribution to satisfy itself + pub fn create_partitioning(&self, partition_count: usize) -> Partitioning { + match self { + Distribution::UnspecifiedDistribution => { + Partitioning::UnknownPartitioning(partition_count) + } + Distribution::SinglePartition => Partitioning::UnknownPartitioning(1), + Distribution::HashPartitioned(expr) => { + Partitioning::Hash(expr.clone(), partition_count) + } + } + } +} + pub use datafusion_physical_expr::window::WindowExpr; pub use 
datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -541,4 +877,8 @@ pub mod values; pub mod windows; use crate::execution::context::TaskContext; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{ + expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, +}; pub use datafusion_physical_expr::{expressions, functions, type_coercion, udf}; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 698cface10726..d12eb21e7ef92 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -522,8 +522,8 @@ impl DefaultPhysicalPlanner { && session_state.config.target_partitions > 1 && session_state.config.repartition_windows; - let input_exec = if can_repartition { - let partition_keys = partition_keys + let physical_partition_keys = if can_repartition { + partition_keys .iter() .map(|e| { self.create_physical_expr( @@ -533,19 +533,11 @@ impl DefaultPhysicalPlanner { session_state, ) }) - .collect::>>>()?; - Arc::new(RepartitionExec::try_new( - input_exec, - Partitioning::Hash( - partition_keys, - session_state.config.target_partitions, - ), - )?) + .collect::>>>()? } else { - input_exec + vec![] }; - // add a sort phase let get_sort_keys = |expr: &Expr| match expr { Expr::WindowFunction { ref partition_by, @@ -566,8 +558,8 @@ impl DefaultPhysicalPlanner { let logical_input_schema = input.schema(); - let input_exec = if sort_keys.is_empty() { - input_exec + let physical_sort_keys = if sort_keys.is_empty() { + None } else { let physical_input_schema = input_exec.schema(); let sort_keys = sort_keys @@ -590,11 +582,7 @@ impl DefaultPhysicalPlanner { _ => unreachable!(), }) .collect::>>()?; - Arc::new(if can_repartition { - SortExec::new_with_partitioning(sort_keys, input_exec, true, None) - } else { - SortExec::try_new(sort_keys, input_exec, None)? 
- }) + Some(sort_keys) }; let physical_input_schema = input_exec.schema(); @@ -614,7 +602,9 @@ impl DefaultPhysicalPlanner { window_expr, input_exec, physical_input_schema, - )?)) + physical_partition_keys, + physical_sort_keys, + )?) ) } LogicalPlan::Aggregate(Aggregate { input, @@ -664,16 +654,8 @@ impl DefaultPhysicalPlanner { Arc, AggregateMode, ) = if can_repartition { - // Divide partial hash aggregates into multiple partitions by hash key - let hash_repartition = Arc::new(RepartitionExec::try_new( - initial_aggr, - Partitioning::Hash( - final_group.clone(), - session_state.config.target_partitions, - ), - )?); - // Combine hash aggregates within the partition - (hash_repartition, AggregateMode::FinalPartitioned) + // construct a second aggregation with 'AggregateMode::FinalPartitioned' + (initial_aggr, AggregateMode::FinalPartitioned) } else { // construct a second aggregation, keeping the final column name equal to the // first aggregation and the expressions corresponding to the respective aggregate @@ -941,32 +923,10 @@ impl DefaultPhysicalPlanner { if session_state.config.target_partitions > 1 && session_state.config.repartition_joins { - let (left_expr, right_expr) = join_on - .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) - .unzip(); - // Use hash partition by default to parallelize hash joins Ok(Arc::new(HashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( - physical_left, - Partitioning::Hash( - left_expr, - session_state.config.target_partitions, - ), - )?), - Arc::new(RepartitionExec::try_new( - physical_right, - Partitioning::Hash( - right_expr, - session_state.config.target_partitions, - ), - )?), + physical_left, + physical_right, join_on, join_filter, join_type, @@ -2307,8 +2267,8 @@ mod tests { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn children(&self) -> Vec> { diff --git 
a/datafusion/core/src/physical_plan/projection.rs b/datafusion/core/src/physical_plan/projection.rs index 5fa3c93cdd421..aba68ea8ac1ec 100644 --- a/datafusion/core/src/physical_plan/projection.rs +++ b/datafusion/core/src/physical_plan/projection.rs @@ -21,7 +21,7 @@ //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. use std::any::Any; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -39,6 +39,10 @@ use super::expressions::{Column, PhysicalSortExpr}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::execution::context::TaskContext; +use datafusion_physical_expr::{ + merge_equivalence_properties_with_alias, normalize_out_expr_with_alias_schema, + truncate_equivalence_properties_not_in_schema, +}; use futures::stream::Stream; use futures::stream::StreamExt; @@ -46,11 +50,13 @@ use futures::stream::StreamExt; #[derive(Debug)] pub struct ProjectionExec { /// The projection expressions stored as tuples of (expression, output column name) - expr: Vec<(Arc, String)>, + pub expr: Vec<(Arc, String)>, + /// The alias map used to normalize out expressions like Partitioning and PhysicalSortExpr + pub alias_map: HashMap>, /// The schema once the projection has been applied to the input - schema: SchemaRef, + pub schema: SchemaRef, /// The input plan - input: Arc, + pub input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, } @@ -82,8 +88,21 @@ impl ProjectionExec { input_schema.metadata().clone(), )); + let mut alias_map: HashMap> = HashMap::new(); + for (expression, name) in expr.iter() { + if let Some(column) = expression.as_any().downcast_ref::() { + let new_col_idx = schema.index_of(name)?; + // When the column name is the same, but index does not equal, treat it as Alias + if (column.name() != name) || (column.index() != 
new_col_idx) { + let entry = alias_map.entry(column.clone()).or_insert_with(Vec::new); + entry.push(Column::new(name, new_col_idx)); + } + }; + } + Ok(Self { expr, + alias_map, schema, input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), @@ -118,10 +137,28 @@ impl ExecutionPlan for ProjectionExec { /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() + // Output partition need to respect the Alias + let input_partition = self.input.output_partitioning(); + match input_partition { + Partitioning::Hash(exprs, part) => { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + normalize_out_expr_with_alias_schema( + expr, + &self.alias_map, + &self.schema, + ) + }) + .collect::>(); + Partitioning::Hash(normalized_exprs, part) + } + _ => input_partition, + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + // TODO Output ordering need to respect the Alias self.input.output_ordering() } @@ -130,8 +167,21 @@ impl ExecutionPlan for ProjectionExec { true } - fn relies_on_input_order(&self) -> bool { - false + // Equivalence properties need to be adjusted after the Projection. 
+ // 1) Add Alias, Alias can introduce additional equivalence properties, + // For example: Projection(a, a as a1, a as a2) + // 2) Truncate the properties that are not in the schema of the Projection + fn equivalence_properties(&self) -> Vec> { + let mut input_equivalence_properties = self.input.equivalence_properties(); + merge_equivalence_properties_with_alias( + &mut input_equivalence_properties, + &self.alias_map, + ); + truncate_equivalence_properties_not_in_schema( + &mut input_equivalence_properties, + &self.schema, + ); + input_equivalence_properties } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 5611989f0ea9c..d9deb5c9140b3 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -39,6 +39,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::execution::context::TaskContext; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalExpr; use futures::stream::Stream; use futures::StreamExt; @@ -272,10 +273,6 @@ impl ExecutionPlan for RepartitionExec { vec![self.input.clone()] } - fn relies_on_input_order(&self) -> bool { - false - } - fn with_new_children( self: Arc, children: Vec>, @@ -294,6 +291,10 @@ impl ExecutionPlan for RepartitionExec { None } + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() + } + fn execute( &self, partition: usize, diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 763c7c553f41f..03510c7f43ef5 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -46,6 +46,7 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::ipc::reader::FileReader; use 
arrow::record_batch::RecordBatch; use async_trait::async_trait; +use datafusion_physical_expr::expressions::Column; use futures::lock::Mutex; use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; use log::{debug, error}; @@ -743,11 +744,13 @@ impl ExecutionPlan for SortExec { } } - fn required_child_distribution(&self) -> Distribution { + fn required_input_distribution(&self) -> Vec { if self.preserve_partitioning { - Distribution::UnspecifiedDistribution + vec![Distribution::UnspecifiedDistribution] } else { - Distribution::SinglePartition + // global sort + // TODO support RangePartition and OrderedDistribution + vec![Distribution::SinglePartition] } } @@ -755,12 +758,7 @@ impl ExecutionPlan for SortExec { vec![self.input.clone()] } - fn relies_on_input_order(&self) -> bool { - // this operator resorts everything - false - } - - fn benefits_from_input_partitioning(&self) -> bool { + fn prefer_parallel(&self) -> bool { false } @@ -768,6 +766,10 @@ impl ExecutionPlan for SortExec { Some(&self.expr) } + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() + } + fn with_new_children( self: Arc, children: Vec>, diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 22bd83ec9ae8b..3f4e424b363c1 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -32,6 +32,7 @@ use arrow::{ error::Result as ArrowResult, record_batch::RecordBatch, }; +use datafusion_physical_expr::expressions::Column; use futures::stream::{Fuse, FusedStream}; use futures::{Stream, StreamExt}; use log::debug; @@ -123,18 +124,22 @@ impl ExecutionPlan for SortPreservingMergeExec { Partitioning::UnknownPartitioning(1) } - fn required_child_distribution(&self) -> Distribution { - Distribution::UnspecifiedDistribution + fn required_input_distribution(&self) -> Vec { + 
vec![Distribution::UnspecifiedDistribution] } - fn relies_on_input_order(&self) -> bool { - true + fn required_input_ordering(&self) -> Vec> { + vec![Some(&self.expr)] } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { Some(&self.expr) } + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() + } + fn children(&self) -> Vec> { vec![self.input.clone()] } diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs index bf9dfbd1b694c..ec297a4069f16 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/core/src/physical_plan/union.rs @@ -27,9 +27,11 @@ use arrow::{ datatypes::{Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::sort_expr_list_eq_strict_order; use futures::StreamExt; use itertools::Itertools; -use log::debug; +use log::{debug, warn}; use super::{ expressions::PhysicalSortExpr, @@ -38,6 +40,7 @@ use super::{ SendableRecordBatchStream, Statistics, }; use crate::execution::context::TaskContext; +use crate::physical_plan::CombinedRecordBatchStream; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -52,6 +55,8 @@ pub struct UnionExec { metrics: ExecutionPlanMetricsSet, /// Schema of Union schema: SchemaRef, + /// Partition aware Union + partition_aware: bool, } impl UnionExec { @@ -78,10 +83,24 @@ impl UnionExec { inputs[0].schema().metadata().clone(), )); + // If all the input partitions have the same Hash partition spec with the first_input_partition + // The UnionExec is partition aware. + // + // It might be too strict here in the case that the input partition specs are compatible but not exactly the same. + // For example one input partition has the partition spec Hash('a','b','c') and + // other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c'). 
+ let first_input_partition = inputs[0].output_partitioning(); + let partition_aware = matches!(first_input_partition, Partitioning::Hash(_, _)) + && inputs + .iter() + .map(|plan| plan.output_partitioning()) + .all(|partition| partition == first_input_partition); + UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), schema, + partition_aware, } } @@ -105,25 +124,52 @@ impl ExecutionPlan for UnionExec { self.inputs.clone() } - /// Output of the union is the combination of all output partitions of the inputs fn output_partitioning(&self) -> Partitioning { - // Sums all the output partitions - let num_partitions = self - .inputs - .iter() - .map(|plan| plan.output_partitioning().partition_count()) - .sum(); - // TODO: this loses partitioning info in case of same partitioning scheme (for example `Partitioning::Hash`) - // https://issues.apache.org/jira/browse/ARROW-11991 - Partitioning::UnknownPartitioning(num_partitions) + if self.partition_aware { + self.inputs[0].output_partitioning() + } else { + // Output the combination of all output partitions of the inputs if the Union is not partition aware + let num_partitions = self + .inputs + .iter() + .map(|plan| plan.output_partitioning().partition_count()) + .sum(); + + Partitioning::UnknownPartitioning(num_partitions) + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + let first_input_ordering = self.inputs[0].output_ordering(); + // If the Union is not partition aware and all the input ordering spec strictly equal with the first_input_ordering + // Return the first_input_ordering as the output_ordering + // + // It might be too strict here in the case that the input ordering are compatible but not exactly the same. + // For example one input ordering has the ordering spec SortExpr('a','b','c') and the other has the ordering + // spec SortExpr('a'), It is safe to derive the out ordering with the spec SortExpr('a'). 
+ if !self.partition_aware + && first_input_ordering.is_some() + && self + .inputs + .iter() + .map(|plan| plan.output_ordering()) + .all(|ordering| { + ordering.is_some() + && sort_expr_list_eq_strict_order( + ordering.unwrap(), + first_input_ordering.unwrap(), + ) + }) + { + first_input_ordering + } else { + None + } } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + // TODO calculate the common equivalence properties among all the inputs + vec![] } fn with_new_children( @@ -145,19 +191,38 @@ impl ExecutionPlan for UnionExec { let elapsed_compute = baseline_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); // record on drop - // find partition to execute - for input in self.inputs.iter() { - // Calculate whether partition belongs to the current partition - if partition < input.output_partitioning().partition_count() { - let stream = input.execute(partition, context)?; - debug!("Found a Union partition to execute"); + if self.partition_aware { + let mut input_stream_vec = vec![]; + for input in self.inputs.iter() { + if partition < input.output_partitioning().partition_count() { + input_stream_vec.push(input.execute(partition, context.clone())?); + } else { + // Do not find a partition to execute + break; + } + } + if input_stream_vec.len() == self.inputs.len() { + let stream = Box::pin(CombinedRecordBatchStream::new( + self.schema(), + input_stream_vec, + )); return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); - } else { - partition -= input.output_partitioning().partition_count(); + } + } else { + // find partition to execute + for input in self.inputs.iter() { + // Calculate whether partition belongs to the current partition + if partition < input.output_partitioning().partition_count() { + let stream = input.execute(partition, context)?; + debug!("Found a Union partition to execute"); + return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); + } else { + 
partition -= input.output_partitioning().partition_count(); + } } } - debug!("Error in Union: Partition {} not found", partition); + warn!("Error in Union: Partition {} not found", partition); Err(crate::error::DataFusionError::Execution(format!( "Partition {} not found in Union", @@ -189,7 +254,7 @@ impl ExecutionPlan for UnionExec { .unwrap_or_default() } - fn benefits_from_input_partitioning(&self) -> bool { + fn prefer_parallel(&self) -> bool { false } } diff --git a/datafusion/core/src/physical_plan/values.rs b/datafusion/core/src/physical_plan/values.rs index 897936814ceea..56a8d2e494ac0 100644 --- a/datafusion/core/src/physical_plan/values.rs +++ b/datafusion/core/src/physical_plan/values.rs @@ -22,13 +22,14 @@ use super::{common, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::physical_plan::{ - memory::MemoryStream, ColumnarValue, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, PhysicalExpr, + memory::MemoryStream, ColumnarValue, DisplayFormatType, ExecutionPlan, Partitioning, + PhysicalExpr, }; use crate::scalar::ScalarValue; use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_physical_expr::expressions::Column; use std::any::Any; use std::sync::Arc; @@ -109,10 +110,6 @@ impl ExecutionPlan for ValuesExec { vec![] } - fn required_child_distribution(&self) -> Distribution { - Distribution::UnspecifiedDistribution - } - /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(1) @@ -122,8 +119,8 @@ impl ExecutionPlan for ValuesExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn equivalence_properties(&self) -> Vec> { + vec![] } fn with_new_children( diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 26cb14fe33a96..1f995c589132c 
100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -210,6 +210,8 @@ mod tests { ], input, schema.clone(), + vec![], + None, )?); let result: Vec = collect(window_exec, task_ctx).await?; @@ -255,6 +257,8 @@ mod tests { )?], blocking_exec, schema, + vec![], + None, )?); let fut = collect(window_agg_exec, task_ctx); diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index e9eac35a3d883..5dfbaee1e5f5f 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -33,8 +33,11 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalExpr; use futures::stream::Stream; use futures::{ready, StreamExt}; +use log::warn; use std::any::Any; use std::pin::Pin; use std::sync::Arc; @@ -44,13 +47,17 @@ use std::task::{Context, Poll}; #[derive(Debug)] pub struct WindowAggExec { /// Input plan - input: Arc, + pub input: Arc, /// Window function expression - window_expr: Vec>, + pub window_expr: Vec>, /// Schema after the window is run schema: SchemaRef, /// Schema before the window - input_schema: SchemaRef, + pub input_schema: SchemaRef, + /// Partition Keys + pub partition_keys: Vec>, + /// Sort Keys + pub sort_keys: Option>, /// Execution metrics metrics: ExecutionPlanMetricsSet, } @@ -61,6 +68,8 @@ impl WindowAggExec { window_expr: Vec>, input: Arc, input_schema: SchemaRef, + partition_keys: Vec>, + sort_keys: Option>, ) -> Result { let schema = create_schema(&input_schema, &window_expr)?; let schema = Arc::new(schema); @@ -69,6 +78,8 @@ impl WindowAggExec { window_expr, schema, input_schema, + partition_keys, + sort_keys, metrics: ExecutionPlanMetricsSet::new(), }) } @@ -119,22 +130,25 @@ impl ExecutionPlan for WindowAggExec 
{ true } - fn relies_on_input_order(&self) -> bool { - true + fn required_input_ordering(&self) -> Vec> { + let sort_keys = self.sort_keys.as_deref(); + vec![sort_keys] } - fn required_child_distribution(&self) -> Distribution { - if self - .window_expr() - .iter() - .all(|expr| expr.partition_by().is_empty()) - { - Distribution::SinglePartition + fn required_input_distribution(&self) -> Vec { + if self.partition_keys.is_empty() { + warn!("No partition defined for WindowAggExec!!!"); + vec![Distribution::SinglePartition] } else { - Distribution::UnspecifiedDistribution + //TODO support PartitionCollections if there is no common partition columns in the window_expr + vec![Distribution::HashPartitioned(self.partition_keys.clone())] } } + fn equivalence_properties(&self) -> Vec> { + self.input.equivalence_properties() + } + fn with_new_children( self: Arc, children: Vec>, @@ -143,6 +157,8 @@ impl ExecutionPlan for WindowAggExec { self.window_expr.clone(), children[0].clone(), self.input_schema.clone(), + self.partition_keys.clone(), + self.sort_keys.clone(), )?)) } diff --git a/datafusion/core/src/scheduler/pipeline/execution.rs b/datafusion/core/src/scheduler/pipeline/execution.rs index 20e7c6e79a48c..66257f29f7538 100644 --- a/datafusion/core/src/scheduler/pipeline/execution.rs +++ b/datafusion/core/src/scheduler/pipeline/execution.rs @@ -22,6 +22,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll, Waker}; +use datafusion_physical_expr::expressions::Column; use futures::{Stream, StreamExt}; use parking_lot::Mutex; @@ -235,20 +236,20 @@ impl ExecutionPlan for ProxyExecutionPlan { self.inner.output_ordering() } - fn required_child_distribution(&self) -> Distribution { - self.inner.required_child_distribution() - } - - fn relies_on_input_order(&self) -> bool { - self.inner.relies_on_input_order() + fn required_input_distribution(&self) -> Vec { + self.inner.required_input_distribution() } fn maintains_input_order(&self) -> bool { 
self.inner.maintains_input_order() } - fn benefits_from_input_partitioning(&self) -> bool { - self.inner.benefits_from_input_partitioning() + fn prefer_parallel(&self) -> bool { + self.inner.prefer_parallel() + } + + fn equivalence_properties(&self) -> Vec> { + vec![] } fn children(&self) -> Vec> { diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 855bb3bbc11e7..f15afe947695e 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -30,6 +30,7 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use datafusion_physical_expr::expressions::Column; use futures::Stream; use crate::execution::context::TaskContext; @@ -154,6 +155,10 @@ impl ExecutionPlan for MockExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { unimplemented!() } @@ -292,6 +297,10 @@ impl ExecutionPlan for BarrierExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { unimplemented!() } @@ -392,6 +401,10 @@ impl ExecutionPlan for ErrorExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { unimplemented!() } @@ -471,6 +484,10 @@ impl ExecutionPlan for StatisticsExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } @@ -569,6 +586,10 @@ impl ExecutionPlan for BlockingExec { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn with_new_children( self: Arc, _: Vec>, diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index 14daf7b6d60f6..933db5a902d18 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -37,6 +37,7 @@ use datafusion::{ physical_plan::collect, }; use datafusion::{error::Result, physical_plan::DisplayFormatType}; +use datafusion_physical_expr::expressions::Column; use 
futures::stream::Stream; use std::any::Any; @@ -117,6 +118,10 @@ impl ExecutionPlan for CustomExecutionPlan { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index c1aa5ad7095c4..9975cb41eee20 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -75,6 +75,12 @@ impl ExecutionPlan for CustomPlan { None } + fn equivalence_properties( + &self, + ) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } diff --git a/datafusion/core/tests/statistics.rs b/datafusion/core/tests/statistics.rs index bcf82481cf1f5..c164fc5fc9d69 100644 --- a/datafusion/core/tests/statistics.rs +++ b/datafusion/core/tests/statistics.rs @@ -35,6 +35,7 @@ use datafusion::{ use async_trait::async_trait; use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion_physical_expr::expressions::Column; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -128,6 +129,10 @@ impl ExecutionPlan for StatisticsValidation { None } + fn equivalence_properties(&self) -> Vec> { + vec![] + } + fn children(&self) -> Vec> { vec![] } diff --git a/datafusion/core/tests/user_defined_plan.rs b/datafusion/core/tests/user_defined_plan.rs index c93b7c3eb4065..1f9087c820b23 100644 --- a/datafusion/core/tests/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined_plan.rs @@ -84,6 +84,7 @@ use datafusion::{ }, prelude::{SessionConfig, SessionContext}, }; +use datafusion_physical_expr::expressions::Column; use fmt::Debug; use std::task::{Context, Poll}; @@ -441,12 +442,12 @@ impl ExecutionPlan for TopKExec { None } - fn relies_on_input_order(&self) -> bool { - false + fn required_input_distribution(&self) -> Vec { + vec![Distribution::SinglePartition] } - fn required_child_distribution(&self) -> 
Distribution { - Distribution::SinglePartition + fn equivalence_properties(&self) -> Vec> { + vec![] } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index e0703829d394a..d281b9543b281 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -48,6 +48,7 @@ datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } datafusion-row = { path = "../row", version = "13.0.0" } hashbrown = { version = "0.12", features = ["raw"] } +itertools = { version = "0.10", features = ["use_std"] } lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } ordered-float = "3.0" diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d6c1240a0d87b..8b93f49c1a034 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -74,6 +74,7 @@ use kernels_arrow::{ use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; +use crate::physical_expr::down_cast_any_ref; use crate::{ExprBoundaries, PhysicalExpr, PhysicalExprStats}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_common::{DataFusionError, Result}; @@ -649,6 +650,30 @@ impl PhysicalExpr for BinaryExpr { right: Arc::clone(self.right()), }) } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BinaryExpr::new( + children[0].clone(), + self.op, + children[1].clone(), + ))) + } +} + +impl PartialEq for BinaryExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right)) + .unwrap_or(false) + } } struct BinaryExprStats { diff --git 
a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 581d9abdce382..88365e0be1c4c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -26,6 +26,9 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use itertools::Itertools; + +use crate::expressions::no_op::NoOp; type WhenThen = (Arc, Arc); @@ -286,6 +289,67 @@ impl PhysicalExpr for CaseExpr { self.case_when_no_expr(batch) } } + + fn children(&self) -> Vec> { + let mut chileren = vec![]; + match &self.expr { + Some(expr) => chileren.push(expr.clone()), + None => chileren.push(Arc::new(NoOp::new())), + } + self.when_then_expr.iter().for_each(|(cond, value)| { + chileren.push(cond.clone()); + chileren.push(value.clone()); + }); + + match &self.else_expr { + Some(expr) => chileren.push(expr.clone()), + None => chileren.push(Arc::new(NoOp::new())), + } + chileren + } + + // For physical CaseExpr, we do not allow modifying children size + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != self.children().len() { + Err(DataFusionError::Internal( + "CaseExpr: Wrong number of children".to_string(), + )) + } else { + assert_eq!(children.len() % 2, 0); + let expr = match children[0].clone().as_any().downcast_ref::() { + Some(_) => None, + _ => Some(children[0].clone()), + }; + let else_expr = match children[children.len() - 1] + .clone() + .as_any() + .downcast_ref::() + { + Some(_) => None, + _ => Some(children[children.len() - 1].clone()), + }; + + let branches = children[1..children.len() - 1].to_vec(); + let mut when_then_expr: Vec = vec![]; + for (prev, next) in branches.into_iter().tuples() { + when_then_expr.push((prev, next)); + } + Ok(Arc::new(CaseExpr::try_new( + expr, + when_then_expr, + else_expr, + )?)) + } + } +} + +impl PartialEq 
for CaseExpr { + fn eq(&self, _other: &dyn Any) -> bool { + false + } } /// Create a CASE expression diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 99d2859e710af..02a110f5aae07 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::kernels; @@ -93,6 +94,36 @@ impl PhysicalExpr for CastExpr { let value = self.expr.evaluate(batch)?; cast_column(&value, &self.cast_type, &self.cast_options) } + + fn children(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(CastExpr::new( + children[0].clone(), + self.cast_type.clone(), + CastOptions { + safe: self.cast_options.safe, + }, + ))) + } +} + +impl PartialEq for CastExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.expr.eq(&x.expr) + && self.cast_type == x.cast_type + && self.cast_options.safe == x.cast_options.safe + }) + .unwrap_or(false) + } } /// Internal cast function for casting ColumnarValue -> ColumnarValue for cast_type diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index dfb6eca40d3ff..450919aa3f2cc 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -17,6 +17,7 @@ //! 
Column expression +use std::any::Any; use std::sync::Arc; use arrow::{ @@ -24,6 +25,7 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::physical_expr::down_cast_any_ref; use crate::{ExprBoundaries, PhysicalExpr, PhysicalExprStats}; use datafusion_common::{ColumnStatistics, DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -94,6 +96,26 @@ impl PhysicalExpr for Column { fn expr_stats(&self) -> Arc { Arc::new(ColumnExprStats { index: self.index }) } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } +} + +impl PartialEq for Column { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } } #[derive(Debug, Clone)] diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index fa021f61a940f..3c7096c7256cd 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -16,6 +16,7 @@ // under the License. use crate::expressions::delta::shift_months; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::{ Array, ArrayRef, Date32Array, Date64Array, TimestampMicrosecondArray, @@ -42,6 +43,7 @@ pub struct DateTimeIntervalExpr { lhs: Arc, op: Operator, rhs: Arc, + input_schema: Schema, } impl DateTimeIntervalExpr { @@ -56,7 +58,12 @@ impl DateTimeIntervalExpr { DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) => { match rhs.data_type(input_schema)? 
{ DataType::Interval(_) => match &op { - Operator::Plus | Operator::Minus => Ok(Self { lhs, op, rhs }), + Operator::Plus | Operator::Minus => Ok(Self { + lhs, + op, + rhs, + input_schema: input_schema.clone(), + }), _ => Err(DataFusionError::Execution(format!( "Invalid operator '{}' for DateIntervalExpr", op @@ -140,6 +147,31 @@ impl PhysicalExpr for DateTimeIntervalExpr { ColumnarValue::Array(array) => evaluate_array(array, sign, intervals), } } + + fn children(&self) -> Vec> { + vec![self.lhs.clone(), self.rhs.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(DateTimeIntervalExpr::try_new( + children[0].clone(), + self.op, + children[1].clone(), + &self.input_schema, + )?)) + } +} + +impl PartialEq for DateTimeIntervalExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.lhs.eq(&x.lhs) && self.op == x.op && self.rhs.eq(&x.rhs)) + .unwrap_or(false) + } } pub fn evaluate_array( diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 5d9b1594d4383..ff10c06e2206f 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -22,6 +22,7 @@ use arrow::array::Array; use arrow::array::{ListArray, StructArray}; use arrow::compute::concat; +use crate::physical_expr::down_cast_any_ref; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -132,6 +133,29 @@ impl PhysicalExpr for GetIndexedFieldExpr { (dt, key) => Err(DataFusionError::Execution(format!("get indexed field is only possible on lists with int64 indexes or struct with utf8 indexes. 
Tried {:?} with {:?} index", dt, key))), } } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(GetIndexedFieldExpr::new( + children[0].clone(), + self.key.clone(), + ))) + } +} + +impl PartialEq for GetIndexedFieldExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg) && self.key == x.key) + .unwrap_or(false) + } } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2d34115220ece..a43b6042b2485 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::collections::HashSet; +use std::fmt::Debug; use std::sync::Arc; use arrow::array::GenericStringArray; @@ -33,6 +34,7 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::physical_expr::{down_cast_any_ref, expr_list_eq_any_order}; use crate::PhysicalExpr; use arrow::array::*; use arrow::datatypes::TimeUnit; @@ -52,16 +54,27 @@ use datafusion_expr::ColumnarValue; static OPTIMIZER_INSET_THRESHOLD: usize = 30; /// InList -#[derive(Debug)] pub struct InListExpr { expr: Arc, list: Vec>, negated: bool, set: Option, + input_schema: Schema, +} + +impl Debug for InListExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("InListExpr") + .field("expr", &self.expr) + .field("list", &self.list) + .field("negated", &self.negated) + .field("set", &self.set) + .finish() + } } /// InSet -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub struct InSet { // TODO: optimization: In the `IN` or `NOT IN` we don't need to consider the NULL value // The data type is same, we can use set: HashSet @@ -386,6 +399,7 @@ impl InListExpr { set: Some(InSet::new(set)), list, negated, + input_schema: schema.clone(), }; } } @@ -394,6 +408,7 @@ 
impl InListExpr { list, negated, set: None, + input_schema: schema.clone(), } } @@ -899,6 +914,39 @@ impl PhysicalExpr for InListExpr { } } } + + fn children(&self) -> Vec> { + let mut children = vec![]; + children.push(self.expr.clone()); + children.extend(self.list.clone()); + children + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + in_list( + children[0].clone(), + children[1..].to_vec(), + &self.negated, + &self.input_schema, + ) + } +} + +impl PartialEq for InListExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.expr.eq(&x.expr) + && expr_list_eq_any_order(&self.list, &x.list) + && self.negated == x.negated + && self.set == x.set + }) + .unwrap_or(false) + } } /// Creates a unary expression InList diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 6b614f3d98ca1..4e24159acd13b 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -19,6 +19,7 @@ use std::{any::Any, sync::Arc}; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; use arrow::{ @@ -79,6 +80,26 @@ impl PhysicalExpr for IsNotNullExpr { )), } } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(IsNotNullExpr::new(children[0].clone()))) + } +} + +impl PartialEq for IsNotNullExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg)) + .unwrap_or(false) + } } /// Create an IS NOT NULL expression diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index e5dbfbdc74819..6ee11820c7f09 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ 
b/datafusion/physical-expr/src/expressions/is_null.rs @@ -25,6 +25,7 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -80,6 +81,26 @@ impl PhysicalExpr for IsNullExpr { )), } } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(IsNullExpr::new(children[0].clone()))) + } +} + +impl PartialEq for IsNullExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg)) + .unwrap_or(false) + } } /// Create an IS NULL expression diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index b3e4a7a320aca..7957aaef240e8 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -25,13 +25,14 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::physical_expr::down_cast_any_ref; use crate::{ExprBoundaries, PhysicalExpr, PhysicalExprStats}; use datafusion_common::ScalarValue; use datafusion_common::{ColumnStatistics, Result}; use datafusion_expr::{ColumnarValue, Expr}; /// Represents a literal value -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub struct Literal { value: ScalarValue, } @@ -77,6 +78,26 @@ impl PhysicalExpr for Literal { value: self.value.clone(), }) } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } +} + +impl PartialEq for Literal { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } } struct LiteralExprStats { diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 208e6d0b51fba..e97482075dc27 100644 
--- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -30,6 +30,7 @@ mod is_not_null; mod is_null; mod literal; mod negative; +mod no_op; mod not; mod nullif; mod try_cast; @@ -80,6 +81,7 @@ pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; +pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use nullif::nullif_func; pub use try_cast::{try_cast, TryCastExpr}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 4f4def99e8210..1fed005ce3b28 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -28,6 +28,7 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ @@ -110,6 +111,26 @@ impl PhysicalExpr for NegativeExpr { } } } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(NegativeExpr::new(children[0].clone()))) + } +} + +impl PartialEq for NegativeExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg)) + .unwrap_or(false) + } } /// Creates a unary expression NEGATIVE diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs new file mode 100644 index 0000000000000..b81bcfc30b905 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! NoOp placeholder expression for physical operations + +use std::any::Any; +use std::sync::Arc; + +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use crate::physical_expr::down_cast_any_ref; +use crate::PhysicalExpr; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; + +/// A placeholder expression that cannot be evaluated +#[derive(Debug, PartialEq, Eq, Default)] +pub struct NoOp {} + +impl NoOp { + /// Create a NoOp expression + pub fn new() -> Self { + Self {} + } +} + +impl std::fmt::Display for NoOp { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "NoOp") + } +} + +impl PhysicalExpr for NoOp { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Null) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unimplemented!("NoOp::evaluate"); + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } +} + +impl PartialEq for NoOp { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} diff --git
a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index a7fba60ec362a..00f1670af7dc3 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -21,6 +21,7 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; @@ -103,6 +104,26 @@ impl PhysicalExpr for NotExpr { } } } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(NotExpr::new(children[0].clone()))) + } +} + +impl PartialEq for NotExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg)) + .unwrap_or(false) + } } /// Creates a unary expression NOT diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 6a03a6b5306c3..299d29ec6072d 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::kernels; @@ -90,6 +91,29 @@ impl PhysicalExpr for TryCastExpr { } } } + + fn children(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(TryCastExpr::new( + children[0].clone(), + self.cast_type.clone(), + ))) + } +} + +impl PartialEq for TryCastExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.expr.eq(&x.expr) && self.cast_type == x.cast_type) + .unwrap_or(false) + } } /// Return a PhysicalExpression representing `expr` casted to diff --git 
a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 965ca330056df..ff857c3ef46b7 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -43,7 +43,17 @@ pub mod window; // reexport this to maintain compatibility with anything that used from_slice previously pub use aggregate::AggregateExpr; pub use datafusion_common::from_slice; -pub use physical_expr::{ExprBoundaries, PhysicalExpr, PhysicalExprStats}; +pub use physical_expr::{ + combine_equivalence_properties, expr_list_eq_any_order, expr_list_eq_strict_order, + merge_equivalence_properties_with_alias, normalize_expr_with_equivalence_properties, + normalize_out_expr_with_alias_schema, + normalize_sort_expr_with_equivalence_properties, remove_equivalence_properties, + sort_expr_list_eq_strict_order, split_predicate, + truncate_equivalence_properties_not_in_schema, +}; +pub use physical_expr::{ + ExprBoundaries, PhysicalExpr, PhysicalExprStats, TreeNodeRewritable, +}; pub use planner::create_physical_expr; pub use scalar_function::ScalarFunctionExpr; pub use sort_expr::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index df82cc07f1eeb..5335f67a403cf 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -15,23 +15,25 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::{ColumnStatistics, DataFusionError, Result, ScalarValue}; -use datafusion_common::{ColumnStatistics, Result, ScalarValue}; - -use datafusion_expr::ColumnarValue; -use std::fmt::{Debug, Display}; -use std::sync::Arc; +use datafusion_expr::{ColumnarValue, Operator}; +use crate::expressions::{BinaryExpr, Column}; +use crate::PhysicalSortExpr; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, filter_record_batch, is_not_null, SlicesIterator}; use std::any::Any; +use std::collections::HashMap; +use std::fmt::{Debug, Display}; +use std::sync::Arc; /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. -pub trait PhysicalExpr: Send + Sync + Display + Debug { +pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// Returns the physical expression as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -66,6 +68,181 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug { fn expr_stats(&self) -> Arc { Arc::new(BasicExpressionStats {}) } + + /// Get a list of child PhysicalExpr that provide the input for this plan. + fn children(&self) -> Vec>; + + /// Returns a new PhysicalExpr where all children were replaced by new exprs. + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result>; +} + +/// a Trait for marking tree node types that are rewritable +pub trait TreeNodeRewritable: Clone { + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. 
+ /// + /// For a node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`RewriteRecursion::Stop`] is returned on a call to pre_visit, no + /// children of that node are visited, nor is mutate + /// called on that node + /// + fn transform_using>( + self, + rewriter: &mut R, + ) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + let after_op_children = + self.map_children(|node| node.transform_using(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + + /// Convenience utility for writing optimizer rules: recursively apply the given `op` to the node tree. + /// When `op` does not apply to a given node, it is left unchanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + self.transform_up(op) + } + + /// Convenience utility for writing optimizer rules: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged.
+ fn transform_down(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + let node_cloned = self.clone(); + let after_op = match op(node_cloned) { + Some(value) => value, + None => self, + }; + after_op.map_children(|node| node.transform_down(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up(self, op: &F) -> Result + where + F: Fn(Self) -> Option, + { + let after_op_children = self.map_children(|node| node.transform_up(op))?; + + let after_op_children_clone = after_op_children.clone(); + let new_node = match op(after_op_children) { + Some(value) => value, + None => after_op_children_clone, + }; + Ok(new_node) + } + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result; +} + +/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node +/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is +/// invoked recursively on all nodes of a tree. +pub trait TreeNodeRewriter: Sized { + /// Invoked before (Preorder) any children of `node` are rewritten / + /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _node: &N) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after (Postorder) all children of `node` have been mutated and + /// returns a potentially modified node. + fn mutate(&mut self, node: N) -> Result; +} + +/// Controls how the [TreeNodeRewriter] recursion should proceed. +#[allow(dead_code)] +pub enum RewriteRecursion { + /// Continue rewrite / visit this node tree. + Continue, + /// Call 'op' immediately and return. + Mutate, + /// Do not rewrite / visit the children of this node. 
+ Stop, + /// Keep recursing but skip applying op on this node + Skip, +} + +impl TreeNodeRewritable for Arc { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + if !self.children().is_empty() { + let new_children: Result> = + self.children().into_iter().map(transform).collect(); + with_new_children_if_necessary(self, new_children?) + } else { + Ok(self) + } + } +} + +/// Returns a copy of this expr if we change any child according to the pointer comparison. +/// The size of `children` must be equal to the size of `PhysicalExpr::children()`. +/// Allow the vtable address comparisons for PhysicalExpr Trait Objects; it is harmless even +/// in the case of a false negative. +#[allow(clippy::vtable_address_comparisons)] +pub fn with_new_children_if_necessary( + expr: Arc, + children: Vec>, +) -> Result> { + if children.len() != expr.children().len() { + Err(DataFusionError::Internal( + "PhysicalExpr: Wrong number of children".to_string(), + )) + } else if children.is_empty() + || children + .iter() + .zip(expr.children().iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + { + expr.with_new_children(children) + } else { + Ok(expr) + } } /// Statistics about the result of a single expression. @@ -171,12 +348,260 @@ fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { Ok(make_array(data)) } +/// Check whether the two expr lists are equal regardless of order.
+/// For example two InListExpr can be considered to be equals no matter the order: +/// +/// In('a','b','c') == In('c','b','a') +pub fn expr_list_eq_any_order( + list1: &[Arc], + list2: &[Arc], +) -> bool { + if list1.len() == list2.len() { + let mut expr_vec1 = list1.to_vec(); + let mut expr_vec2 = list2.to_vec(); + while let Some(expr1) = expr_vec1.pop() { + if let Some(idx) = expr_vec2.iter().position(|expr2| expr1.eq(expr2)) { + expr_vec2.swap_remove(idx); + } else { + break; + } + } + expr_vec1.is_empty() && expr_vec2.is_empty() + } else { + false + } +} + +/// Strictly compare the two expr lists are equal in the given order. +pub fn expr_list_eq_strict_order( + list1: &[Arc], + list2: &[Arc], +) -> bool { + list1.len() == list2.len() && list1.iter().zip(list2.iter()).all(|(e1, e2)| e1.eq(e2)) +} + +/// Strictly compare the two sort expr lists in the given order. +/// +/// For Physical Sort Exprs, the order matters: +/// +/// SortExpr('a','b','c') != SortExpr('c','b','a') +pub fn sort_expr_list_eq_strict_order( + list1: &[PhysicalSortExpr], + list2: &[PhysicalSortExpr], +) -> bool { + list1.len() == list2.len() && list1.iter().zip(list2.iter()).all(|(e1, e2)| e1.eq(e2)) +} + +/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +/// +pub fn split_predicate(predicate: &Arc) -> Vec<&Arc> { + match predicate.as_any().downcast_ref::() { + Some(binary) => match binary.op() { + Operator::And => { + let mut vec1 = split_predicate(binary.left()); + let vec2 = split_predicate(binary.right()); + vec1.extend(vec2); + vec1 + } + _ => vec![predicate], + }, + None => vec![], + } +} + +/// Combine the new equal condition with the existing equivalence properties. 
+pub fn combine_equivalence_properties( + eq_properties: &mut Vec>, + new_condition: (&Column, &Column), +) { + let mut idx1 = -1i32; + let mut idx2 = -1i32; + for (idx, prop) in eq_properties.iter_mut().enumerate() { + let contains_first = prop.contains(new_condition.0); + let contains_second = prop.contains(new_condition.1); + if contains_first && !contains_second { + prop.push(new_condition.1.clone()); + idx1 = idx as i32; + } else if !contains_first && contains_second { + prop.push(new_condition.0.clone()); + idx2 = idx as i32; + } else if contains_first && contains_second { + idx1 = idx as i32; + idx2 = idx as i32; + break; + } + } + + if idx1 != -1 && idx2 != -1 && idx1 != idx2 { + // need to merge the two existing properties + let second_properties = eq_properties.get(idx2 as usize).unwrap().clone(); + let first_properties = eq_properties.get_mut(idx1 as usize).unwrap(); + for prop in second_properties { + if !first_properties.contains(&prop) { + first_properties.push(prop) + } + } + eq_properties.remove(idx2 as usize); + } else if idx1 == -1 && idx2 == -1 { + // adding new pairs + eq_properties.push(vec![new_condition.0.clone(), new_condition.1.clone()]) + } +} + +pub fn remove_equivalence_properties( + eq_properties: &mut Vec>, + remove_condition: (&Column, &Column), +) { + let mut match_idx = -1i32; + for (idx, prop) in eq_properties.iter_mut().enumerate() { + let contains_first = prop.contains(remove_condition.0); + let contains_second = prop.contains(remove_condition.1); + if contains_first && contains_second { + match_idx = idx as i32; + break; + } + } + if match_idx >= 0 { + let matches = eq_properties.get_mut(match_idx as usize).unwrap(); + matches.retain(|e| (e != remove_condition.0 && e != remove_condition.1)); + if matches.len() <= 1 { + eq_properties.remove(match_idx as usize); + } + } +} + +pub fn merge_equivalence_properties_with_alias( + eq_properties: &mut Vec>, + alias_map: &HashMap>, +) { + for (column, columns) in alias_map { + let mut 
find_match = false; + for (_idx, prop) in eq_properties.iter_mut().enumerate() { + if prop.contains(column) { + prop.extend(columns.clone()); + find_match = true; + break; + } + } + if !find_match { + let mut new_properties = vec![column.clone()]; + new_properties.extend(columns.clone()); + eq_properties.push(new_properties); + } + } +} + +pub fn truncate_equivalence_properties_not_in_schema( + eq_properties: &mut Vec>, + schema: &SchemaRef, +) { + for props in eq_properties.iter_mut() { + props.retain(|column| matches!(schema.index_of(column.name()), Ok(idx) if idx == column.index())) + } + eq_properties.retain(|props| !props.is_empty()); +} + +/// Normalize the output expressions base on Alias Map and SchemaRef. +/// +/// 1) If there is mapping in Alias Map, replace the Column in the output expressions with the 1st Column in Alias Map +/// 2) If the Column is invalid for the current Schema, replace the Column with a place holder Column with index = usize::MAX +/// +pub fn normalize_out_expr_with_alias_schema( + expr: Arc, + alias_map: &HashMap>, + schema: &SchemaRef, +) -> Arc { + let expr_clone = expr.clone(); + expr_clone + .transform(&|expr| { + let normalized_form: Option> = + match expr.as_any().downcast_ref::() { + Some(column) => { + let out = alias_map + .get(column) + .map(|c| { + let out_col: Arc = + Arc::new(c[0].clone()); + out_col + }) + .or_else(|| match schema.index_of(column.name()) { + // Exactly matching, return None, no need to do the transform + Ok(idx) if column.index() == idx => None, + _ => { + let out_col: Arc = + Arc::new(Column::new(column.name(), usize::MAX)); + Some(out_col) + } + }); + out + } + None => None, + }; + normalized_form + }) + .unwrap_or(expr) +} + +pub fn normalize_expr_with_equivalence_properties( + expr: Arc, + eq_properties: &[Vec], +) -> Arc { + let mut normalized = expr.clone(); + if let Some(column) = expr.as_any().downcast_ref::() { + for prop in eq_properties { + if prop.contains(column) { + normalized = 
Arc::new(prop.get(0).unwrap().clone()); + break; + } + } + } + normalized +} + +pub fn normalize_sort_expr_with_equivalence_properties( + sort_expr: PhysicalSortExpr, + eq_properties: &[Vec], +) -> PhysicalSortExpr { + let mut normalized = sort_expr.clone(); + if let Some(column) = sort_expr.expr.as_any().downcast_ref::() { + for prop in eq_properties { + if prop.contains(column) { + normalized = PhysicalSortExpr { + expr: Arc::new(prop.get(0).unwrap().clone()), + options: sort_expr.options, + }; + break; + } + } + } + normalized +} + +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else { + any + } +} + #[cfg(test)] mod tests { + use crate::expressions::Column; + use crate::PhysicalSortExpr; use std::sync::Arc; use super::*; use arrow::array::Int32Array; + use arrow::compute::SortOptions; use datafusion_common::Result; #[test] @@ -245,6 +670,209 @@ mod tests { Ok(()) } + #[test] + fn expr_list_eq_any_order_test() -> Result<()> { + let list1: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list2: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + ]; + assert!(!expr_list_eq_any_order(list1.as_slice(), list2.as_slice())); + assert!(!expr_list_eq_any_order(list2.as_slice(), list1.as_slice())); + + let list3: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list4: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + ]; + assert!(expr_list_eq_any_order(list3.as_slice(), list4.as_slice())); + 
assert!(expr_list_eq_any_order(list4.as_slice(), list3.as_slice())); + assert!(expr_list_eq_any_order(list3.as_slice(), list3.as_slice())); + assert!(expr_list_eq_any_order(list4.as_slice(), list4.as_slice())); + + Ok(()) + } + + #[test] + fn sort_expr_list_eq_strict_order_test() -> Result<()> { + let list1: Vec = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + + let list2: Vec = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + ]; + + assert!(!sort_expr_list_eq_strict_order( + list1.as_slice(), + list2.as_slice() + )); + assert!(!sort_expr_list_eq_strict_order( + list2.as_slice(), + list1.as_slice() + )); + + let list3: Vec = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: SortOptions::default(), + }, + ]; + let list4: Vec = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: SortOptions::default(), + }, + ]; + + assert!(sort_expr_list_eq_strict_order( + list3.as_slice(), + list4.as_slice() + )); + assert!(sort_expr_list_eq_strict_order( + list4.as_slice(), + list3.as_slice() + )); + 
assert!(sort_expr_list_eq_strict_order( + list3.as_slice(), + list3.as_slice() + )); + assert!(sort_expr_list_eq_strict_order( + list4.as_slice(), + list4.as_slice() + )); + + Ok(()) + } + + #[test] + fn combine_equivalence_properties_test() -> Result<()> { + let mut eq_properties: Vec> = vec![]; + let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); + combine_equivalence_properties(&mut eq_properties, new_condition); + assert_eq!(eq_properties.len(), 1); + + let new_condition = (&Column::new("b", 1), &Column::new("a", 0)); + combine_equivalence_properties(&mut eq_properties, new_condition); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 2); + + let new_condition = (&Column::new("b", 1), &Column::new("c", 2)); + combine_equivalence_properties(&mut eq_properties, new_condition); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 3); + + let new_condition = (&Column::new("x", 99), &Column::new("y", 100)); + combine_equivalence_properties(&mut eq_properties, new_condition); + assert_eq!(eq_properties.len(), 2); + + let new_condition = (&Column::new("x", 99), &Column::new("a", 0)); + combine_equivalence_properties(&mut eq_properties, new_condition); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 5); + + Ok(()) + } + + #[test] + fn remove_equivalence_properties_test() -> Result<()> { + let mut eq_properties: Vec> = vec![]; + let remove_condition = (&Column::new("a", 0), &Column::new("b", 1)); + remove_equivalence_properties(&mut eq_properties, remove_condition); + assert_eq!(eq_properties.len(), 0); + + let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); + combine_equivalence_properties(&mut eq_properties, new_condition); + let new_condition = (&Column::new("a", 0), &Column::new("c", 2)); + combine_equivalence_properties(&mut eq_properties, new_condition); + let new_condition = (&Column::new("c", 2), &Column::new("d", 3)); + combine_equivalence_properties(&mut 
eq_properties, new_condition); + assert_eq!(eq_properties.len(), 1); + + let remove_condition = (&Column::new("a", 0), &Column::new("b", 1)); + remove_equivalence_properties(&mut eq_properties, remove_condition); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 2); + + Ok(()) + } + + #[test] + fn merge_equivalence_properties_with_alias_test() -> Result<()> { + let mut eq_properties: Vec> = vec![]; + let mut alias_map = HashMap::new(); + alias_map.insert( + Column::new("a", 0), + vec![Column::new("a1", 1), Column::new("a2", 2)], + ); + + merge_equivalence_properties_with_alias(&mut eq_properties, &alias_map); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 3); + + let mut alias_map = HashMap::new(); + alias_map.insert( + Column::new("a", 0), + vec![Column::new("a3", 1), Column::new("a4", 2)], + ); + merge_equivalence_properties_with_alias(&mut eq_properties, &alias_map); + assert_eq!(eq_properties.len(), 1); + assert_eq!(eq_properties[0].len(), 5); + Ok(()) + } + #[test] fn reduce_boundaries() -> Result<()> { let different_boundaries = ExprBoundaries::new( @@ -268,7 +896,6 @@ mod tests { let no_boundaries = ExprBoundaries::new(ScalarValue::Int32(None), ScalarValue::Int32(None), None); assert_eq!(no_boundaries.reduce(), Some(ScalarValue::Int32(None))); - Ok(()) } } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1350d49510d58..93968b60f8b2d 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -145,4 +145,26 @@ impl PhysicalExpr for ScalarFunctionExpr { let fun = self.fun.as_ref(); (fun)(&inputs) } + + fn children(&self) -> Vec> { + self.args.clone() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + self.fun.clone(), + children, + self.return_type(), + ))) + } +} + +impl PartialEq for ScalarFunctionExpr 
{ + fn eq(&self, _other: &dyn Any) -> bool { + false + } } diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 79656725d4f44..a173cc8ba8b3b 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -33,6 +33,12 @@ pub struct PhysicalSortExpr { pub options: SortOptions, } +impl PartialEq for PhysicalSortExpr { + fn eq(&self, other: &PhysicalSortExpr) -> bool { + self.options == other.options && self.expr.eq(&other.expr) + } +} + impl std::fmt::Display for PhysicalSortExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let opts_string = match (self.options.descending, self.options.nulls_first) {