From 8bf7370d4de7afa33f7b6e9352a16538b8457350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 20 May 2024 11:25:09 +0800 Subject: [PATCH 1/3] UserDefinedLogicalNode::from_template return Result --- datafusion/expr/src/logical_plan/extension.rs | 15 ++++----- datafusion/expr/src/logical_plan/plan.rs | 4 +-- datafusion/expr/src/logical_plan/tree_node.rs | 32 +++++++++---------- .../substrait/src/logical_plan/consumer.rs | 4 +-- .../tests/cases/roundtrip_logical_plan.rs | 6 ++-- 5 files changed, 29 insertions(+), 32 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 7e6f07e0c5098..ef3282ce4fc10 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -17,7 +17,7 @@ //! This module defines the interface for logical nodes use crate::{Expr, LogicalPlan}; -use datafusion_common::{DFSchema, DFSchemaRef}; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; use std::hash::{Hash, Hasher}; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; @@ -76,10 +76,10 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// For example: `TopK: k=10` fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; - /// Create a new `ExtensionPlanNode` with the specified children + /// Create a new `UserDefinedLogicalNode` with the specified children /// and expressions. This function is used during optimization /// when the plan is being rewritten and a new instance of the - /// `ExtensionPlanNode` must be created. + /// `UserDefinedLogicalNode` must be created. /// /// Note that exprs and inputs are in the same order as the result /// of self.inputs and self.exprs. @@ -88,15 +88,12 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { // // TODO(clippy): This should probably be renamed to use a `with_*` prefix. Something // like `with_template`, or `with_exprs_and_inputs`. - // - // Also, I think `ExtensionPlanNode` has been renamed to `UserDefinedLogicalNode` - // but the doc comments have not been updated. #[allow(clippy::wrong_self_convention)] fn from_template( &self, exprs: &[Expr], inputs: &[LogicalPlan], - ) -> Arc; + ) -> Result>; /// Returns the necessary input columns for this node required to compute /// the columns in the output schema @@ -316,8 +313,8 @@ impl UserDefinedLogicalNode for T { &self, exprs: &[Expr], inputs: &[LogicalPlan], - ) -> Arc { - Arc::new(self.from_template(exprs, inputs)) + ) -> Result> { + Ok(Arc::new(self.from_template(exprs, inputs))) } fn necessary_children_exprs( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4872e5acda5e9..daa44319543bb 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -625,7 +625,7 @@ impl LogicalPlan { let expr = node.expressions(); let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); Ok(LogicalPlan::Extension(Extension { - node: node.from_template(&expr, &inputs), + node: node.from_template(&expr, &inputs)?, })) } LogicalPlan::Union(Union { inputs, schema }) => { @@ -923,7 +923,7 @@ impl LogicalPlan { definition: definition.clone(), }))), LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { - node: e.node.from_template(&expr, &inputs), + node: e.node.from_template(&expr, &inputs)?, })), LogicalPlan::Union(Union { schema, .. }) => { let input_schema = inputs[0].schema(); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 2289eb1639330..056f9c4f36abc 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -417,7 +417,7 @@ where .map_data(|new_inputs| { let exprs = node.expressions(); Ok(Extension { - node: node.from_template(&exprs, &new_inputs), + node: node.from_template(&exprs, &new_inputs)?, }) }) } @@ -658,22 +658,22 @@ impl LogicalPlan { LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - node.expressions() + let exprs = node + .expressions() .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|exprs| { - LogicalPlan::Extension(Extension { - node: UserDefinedLogicalNode::from_template( - node.as_ref(), - exprs.as_slice(), - node.inputs() - .into_iter() - .cloned() - .collect::>() - .as_slice(), - ), - }) - }) + .map_until_stop_and_collect(f)?; + let plan = LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.data.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + )?, + }); + Transformed::new(plan, exprs.transformed, exprs.tnr) } LogicalPlan::TableScan(TableScan { table_name, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index fab4528c0b421..77f3ba552053a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -551,7 +551,7 @@ pub async fn from_substrait_rel( ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; - let plan = plan.from_template(&plan.expressions(), &[input_plan]); + let plan = plan.from_template(&plan.expressions(), &[input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { @@ -567,7 +567,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(ctx, input, extensions).await?; inputs.push(input_plan); } - let plan = plan.from_template(&plan.expressions(), &inputs); + let plan = plan.from_template(&plan.expressions(), &inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 28c0de1c9973f..d0d001dd62e08 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -114,12 +114,12 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { &self, _: &[Expr], inputs: &[LogicalPlan], - ) -> Arc { - Arc::new(Self { + ) -> Result> { + Ok(Arc::new(Self { validation_bytes: self.validation_bytes.clone(), inputs: inputs.to_vec(), empty_schema: Arc::new(DFSchema::empty()), - }) + })) } fn dyn_hash(&self, _: &mut dyn std::hash::Hasher) { From 1c8cb3cf7cb9de12afdfba0c77ab26527d7918ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 20 May 2024 11:27:26 +0800 Subject: [PATCH 2/3] Rename from_template to with_exprs_and_inputs --- datafusion/expr/src/logical_plan/extension.rs | 10 +++------- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- datafusion/expr/src/logical_plan/tree_node.rs | 4 ++-- datafusion/substrait/src/logical_plan/consumer.rs | 4 ++-- .../substrait/tests/cases/roundtrip_logical_plan.rs | 2 +- 5 files changed, 10 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index ef3282ce4fc10..fed43f0a186e4 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -84,12 +84,8 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Note that exprs and inputs are in the same order as the result /// of self.inputs and self.exprs. /// - /// So, `self.from_template(exprs, ..).expressions() == exprs - // - // TODO(clippy): This should probably be renamed to use a `with_*` prefix. Something - // like `with_template`, or `with_exprs_and_inputs`. - #[allow(clippy::wrong_self_convention)] - fn from_template( + /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs + fn with_exprs_and_inputs( &self, exprs: &[Expr], inputs: &[LogicalPlan], @@ -309,7 +305,7 @@ impl UserDefinedLogicalNode for T { self.fmt_for_explain(f) } - fn from_template( + fn with_exprs_and_inputs( &self, exprs: &[Expr], inputs: &[LogicalPlan], diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index daa44319543bb..7ff38da4ec6dd 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -625,7 +625,7 @@ impl LogicalPlan { let expr = node.expressions(); let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); Ok(LogicalPlan::Extension(Extension { - node: node.from_template(&expr, &inputs)?, + node: node.with_exprs_and_inputs(&expr, &inputs)?, })) } LogicalPlan::Union(Union { inputs, schema }) => { @@ -923,7 +923,7 @@ impl LogicalPlan { definition: definition.clone(), }))), LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { - node: e.node.from_template(&expr, &inputs)?, + node: e.node.with_exprs_and_inputs(&expr, &inputs)?, })), LogicalPlan::Union(Union { schema, .. }) => { let input_schema = inputs[0].schema(); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 056f9c4f36abc..77f90a5001d87 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -417,7 +417,7 @@ where .map_data(|new_inputs| { let exprs = node.expressions(); Ok(Extension { - node: node.from_template(&exprs, &new_inputs)?, + node: node.with_exprs_and_inputs(&exprs, &new_inputs)?, }) }) } @@ -663,7 +663,7 @@ impl LogicalPlan { .into_iter() .map_until_stop_and_collect(f)?; let plan = LogicalPlan::Extension(Extension { - node: UserDefinedLogicalNode::from_template( + node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), exprs.data.as_slice(), node.inputs() diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 77f3ba552053a..2215ff82c7d6a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -551,7 +551,7 @@ pub async fn from_substrait_rel( ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; - let plan = plan.from_template(&plan.expressions(), &[input_plan])?; + let plan = plan.with_exprs_and_inputs(&plan.expressions(), &[input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { @@ -567,7 +567,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(ctx, input, extensions).await?; inputs.push(input_plan); } - let plan = plan.from_template(&plan.expressions(), &inputs)?; + let plan = plan.with_exprs_and_inputs(&plan.expressions(), &inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d0d001dd62e08..07dc0c881809a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -110,7 +110,7 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { ) } - fn from_template( + fn with_exprs_and_inputs( &self, _: &[Expr], inputs: &[LogicalPlan], From c4b6e1af52cc77859a3e2d5d3c6ca68141c8c660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 21 May 2024 11:09:52 +0800 Subject: [PATCH 3/3] Resolve review comments --- datafusion/expr/src/logical_plan/extension.rs | 21 ++++++++++++++----- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- datafusion/expr/src/logical_plan/tree_node.rs | 10 +++------ .../substrait/src/logical_plan/consumer.rs | 5 +++-- .../tests/cases/roundtrip_logical_plan.rs | 6 +++--- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index fed43f0a186e4..918e290ee43b7 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -76,6 +76,17 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// For example: `TopK: k=10` fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; + #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")] + #[allow(clippy::wrong_self_convention)] + fn from_template( + &self, + exprs: &[Expr], + inputs: &[LogicalPlan], + ) -> Arc { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + /// Create a new `UserDefinedLogicalNode` with the specified children /// and expressions. This function is used during optimization /// when the plan is being rewritten and a new instance of the @@ -87,8 +98,8 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs fn with_exprs_and_inputs( &self, - exprs: &[Expr], - inputs: &[LogicalPlan], + exprs: Vec, + inputs: Vec, ) -> Result>; /// Returns the necessary input columns for this node required to compute @@ -307,10 +318,10 @@ impl UserDefinedLogicalNode for T { fn with_exprs_and_inputs( &self, - exprs: &[Expr], - inputs: &[LogicalPlan], + exprs: Vec, + inputs: Vec, ) -> Result> { - Ok(Arc::new(self.from_template(exprs, inputs))) + Ok(Arc::new(self.from_template(&exprs, &inputs))) } fn necessary_children_exprs( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7ff38da4ec6dd..42f3e1f163a75 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -625,7 +625,7 @@ impl LogicalPlan { let expr = node.expressions(); let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); Ok(LogicalPlan::Extension(Extension { - node: node.with_exprs_and_inputs(&expr, &inputs)?, + node: node.with_exprs_and_inputs(expr, inputs)?, })) } LogicalPlan::Union(Union { inputs, schema }) => { @@ -923,7 +923,7 @@ impl LogicalPlan { definition: definition.clone(), }))), LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { - node: e.node.with_exprs_and_inputs(&expr, &inputs)?, + node: e.node.with_exprs_and_inputs(expr, inputs)?, })), LogicalPlan::Union(Union { schema, .. }) => { let input_schema = inputs[0].schema(); diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 77f90a5001d87..ea1f1c3c85f76 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -417,7 +417,7 @@ where .map_data(|new_inputs| { let exprs = node.expressions(); Ok(Extension { - node: node.with_exprs_and_inputs(&exprs, &new_inputs)?, + node: node.with_exprs_and_inputs(exprs, new_inputs)?, }) }) } @@ -665,12 +665,8 @@ impl LogicalPlan { let plan = LogicalPlan::Extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), - exprs.data.as_slice(), - node.inputs() - .into_iter() - .cloned() - .collect::>() - .as_slice(), + exprs.data, + node.inputs().into_iter().cloned().collect::>(), )?, }); Transformed::new(plan, exprs.transformed, exprs.tnr) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 2215ff82c7d6a..e16479110671d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -551,7 +551,8 @@ pub async fn from_substrait_rel( ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; - let plan = plan.with_exprs_and_inputs(&plan.expressions(), &[input_plan])?; + let plan = + plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { @@ -567,7 +568,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(ctx, input, extensions).await?; inputs.push(input_plan); } - let plan = plan.with_exprs_and_inputs(&plan.expressions(), &inputs)?; + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 07dc0c881809a..4c7dc87145852 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -112,12 +112,12 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { fn with_exprs_and_inputs( &self, - _: &[Expr], - inputs: &[LogicalPlan], + _: Vec, + inputs: Vec, ) -> Result> { Ok(Arc::new(Self { validation_bytes: self.validation_bytes.clone(), - inputs: inputs.to_vec(), + inputs, empty_schema: Arc::new(DFSchema::empty()), })) }