From 54eef8ea201eebc6f8f99ef69a02ba92f54c8bef Mon Sep 17 00:00:00 2001 From: Louis Vialar Date: Thu, 18 Jun 2026 16:18:37 +0200 Subject: [PATCH 1/2] feat(subqueries): support subqueries in projections --- .../src/decorrelate_predicate_subquery.rs | 164 ++++++++++++------ 1 file changed, 112 insertions(+), 52 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 0609109ec6e58..3c875024b5c14 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -36,8 +36,8 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; use datafusion_expr::{ - BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists, - in_subquery, lit, not, not_exists, not_in_subquery, + BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, Projection, + exists, in_subquery, lit, not, not_exists, not_in_subquery, }; use log::debug; @@ -69,70 +69,104 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? 
.data; - let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); - }; - - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + if let LogicalPlan::Filter(filter) = plan { + rewrite_filter(filter, config) + } else if let LogicalPlan::Projection(project) = plan { + rewrite_project(project, config) + } else { + Ok(Transformed::no(plan)) } + } - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + fn name(&self) -> &str { + "decorrelate_predicate_subquery" + } - assert_or_internal_err!( - !with_subqueries.is_empty(), - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - let original_schema = cur_input.schema().columns(); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match build_join_top(&subquery, &cur_input, config.alias_generator())? 
- { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), - } - } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = - rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); +fn rewrite_filter( + filter: Filter, + config: &dyn OptimizerConfig, +) -> Result> { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + assert_or_internal_err!( + !with_subqueries.is_empty(), + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + let original_schema = cur_input.schema().columns(); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? 
{ + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), } } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } + } - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); - } - - if cur_input.schema().fields().len() != original_schema.len() { - cur_input = LogicalPlanBuilder::from(cur_input) - .project(original_schema.into_iter().map(Expr::from))? - .build()?; - } + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); + } - Ok(Transformed::yes(cur_input)) + if cur_input.schema().fields().len() != original_schema.len() { + cur_input = LogicalPlanBuilder::from(cur_input) + .project(original_schema.into_iter().map(Expr::from))? 
+ .build()?; } - fn name(&self) -> &str { - "decorrelate_predicate_subquery" + Ok(Transformed::yes(cur_input)) +} + +fn rewrite_project( + project: Projection, + config: &dyn OptimizerConfig, +) -> Result> { + if !project.expr.iter().any(has_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(project))); } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + let mut cur_input = Arc::unwrap_or_clone(project.input); + let mut new_projections = Vec::with_capacity(project.expr.len()); + for (expression, field) in project.expr.into_iter().zip(project.schema.fields()) { + if !has_subquery(&expression) { + new_projections.push(expression); + continue; + } + + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expression, config)?; + cur_input = plan; + new_projections.push(expr_without_subqueries.alias(field.name())); } + + let new_project = Projection::try_new(new_projections, Arc::new(cur_input))?; + Ok(Transformed::yes(LogicalPlan::Projection(new_project))) } fn rewrite_inner_subqueries( @@ -546,6 +580,7 @@ mod tests { use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; + use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{and, binary_expr, col, out_ref_col, table_scan}; macro_rules! assert_optimized_plan_equal { @@ -2123,4 +2158,29 @@ mod tests { " ) } + + #[test] + fn exist_projection() -> Result<()> { + let fields = vec![Field::new("A", DataType::UInt32, false)]; + + let schema = Schema::new(fields); + let subquery_scan = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::empty(true) + .project(vec![SelectExpr::Expression(exists(Arc::new( + subquery_scan, + )))])? 
+ .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: __correlated_sq_1.mark AS EXISTS [EXISTS:Boolean] + LeftMark Join: Filter: Boolean(true) [mark:Boolean] + EmptyRelation: rows=1 [] + SubqueryAlias: __correlated_sq_1 [A:UInt32] + TableScan: TEST_A [A:UInt32] + " + ) + } } From 6b32d07777640cda972f44ca544a8ee859155051 Mon Sep 17 00:00:00 2001 From: Louis Vialar Date: Fri, 19 Jun 2026 12:58:29 +0200 Subject: [PATCH 2/2] fix: use full plan for build_join_schema This fixes an issue where mark joins "lose" the qualifier after optimize_projections in EXISTS subqueries. Indeed, in EXISTS subqueries, we don't need any projected field, so optimize_projections removes them all, which makes it impossible for build_join_schema to know what qualifier to use for the mark field. --- datafusion/expr/src/logical_plan/builder.rs | 67 +++++++++++++------ datafusion/expr/src/logical_plan/plan.rs | 11 ++- .../optimizer/src/eliminate_cross_join.rs | 16 ++--- .../optimizer/src/optimize_projections/mod.rs | 4 +- 4 files changed, 57 insertions(+), 41 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2ecb12c30afad..113fc10ee8517 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1154,8 +1154,7 @@ impl LogicalPlanBuilder { .zip(right_keys) .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) .collect(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; + let join_schema = build_join_schema(&self.plan, &right, &join_type)?; // Inner type without join condition is cross join if join_type != JoinType::Inner && on.is_empty() && filter.is_none() { @@ -1652,16 +1651,22 @@ pub fn unique_field_aliases(fields: &Fields) -> Vec> { .collect() } -fn mark_field(schema: &DFSchema) -> (Option, Arc) { - let mut table_references = schema - .iter() - .filter_map(|(qualifier, _)| qualifier) - .collect::>(); - 
table_references.dedup(); - let table_reference = if table_references.len() == 1 { - table_references.pop().cloned() +fn mark_field(child_plan: &LogicalPlan) -> (Option, Arc) { + let table_reference = if let LogicalPlan::SubqueryAlias(plan) = child_plan { + Some(plan.alias.clone()) } else { - None + let mut table_references = child_plan + .schema() + .iter() + .filter_map(|(qualifier, _)| qualifier) + .collect::>(); + table_references.dedup(); + + if table_references.len() == 1 { + table_references.pop().cloned() + } else { + None + } }; ( @@ -1673,8 +1678,8 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { /// Creates a schema for a join operation. /// The fields from the left side are first pub fn build_join_schema( - left: &DFSchema, - right: &DFSchema, + left_plan: &LogicalPlan, + right_plan: &LogicalPlan, join_type: &JoinType, ) -> Result { fn nullify_fields<'a>( @@ -1689,6 +1694,8 @@ pub fn build_join_schema( .collect() } + let left = left_plan.schema(); + let right = right_plan.schema(); let right_fields = right.iter(); let left_fields = left.iter(); @@ -1738,7 +1745,7 @@ pub fn build_join_schema( } JoinType::LeftMark => left_fields .map(|(q, f)| (q.cloned(), Arc::clone(f))) - .chain(once(mark_field(right))) + .chain(once(mark_field(right_plan))) .collect(), JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema @@ -1748,7 +1755,7 @@ pub fn build_join_schema( } JoinType::RightMark => right_fields .map(|(q, f)| (q.cloned(), Arc::clone(f))) - .chain(once(mark_field(left))) + .chain(once(mark_field(left_plan))) .collect(), }; let func_dependencies = left.functional_dependencies().join( @@ -2910,15 +2917,35 @@ mod tests { vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))], HashMap::from([("key".to_string(), "right".to_string())]), )?; - - let join_schema = - build_join_schema(&left_schema, &right_schema, &JoinType::Left)?; + let left_schema = Arc::new(left_schema); + let right_schema = Arc::new(right_schema); + 
+ let join_schema = build_join_schema( + &LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::clone(&left_schema), + }), + &LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::clone(&right_schema), + }), + &JoinType::Left, + )?; assert_eq!( join_schema.metadata(), &HashMap::from([("key".to_string(), "left".to_string())]) ); - let join_schema = - build_join_schema(&left_schema, &right_schema, &JoinType::Right)?; + let join_schema = build_join_schema( + &LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::clone(&left_schema), + }), + &LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::clone(&right_schema), + }), + &JoinType::Right, + )?; assert_eq!( join_schema.metadata(), &HashMap::from([("key".to_string(), "right".to_string())]) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ca6941a61ce6..aa6028c5d8114 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -667,8 +667,7 @@ impl LogicalPlan { null_equality, null_aware, }) => { - let schema = - build_join_schema(left.schema(), right.schema(), &join_type)?; + let schema = build_join_schema(&left, &right, &join_type)?; let new_on: Vec<_> = on .into_iter() @@ -944,7 +943,7 @@ impl LogicalPlan { .. 
}) => { let (left, right) = self.only_two_inputs(inputs)?; - let schema = build_join_schema(left.schema(), right.schema(), join_type)?; + let schema = build_join_schema(&left, &right, join_type)?; let equi_expr_count = on.len() * 2; assert!(expr.len() >= equi_expr_count); @@ -4269,7 +4268,7 @@ impl Join { null_equality: NullEquality, null_aware: bool, ) -> Result { - let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; + let join_schema = build_join_schema(&left, &right, &join_type)?; Ok(Join { left, @@ -4321,8 +4320,8 @@ impl Join { .collect(); let join_schema = build_join_schema( - left_sch.schema(), - right_sch.schema(), + &left_sch.build()?, + &right_sch.build()?, &original_join.join_type, )?; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 95b70da443d88..0db08f5136810 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -371,8 +371,8 @@ fn find_inner_join( all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); let join_schema = Arc::new(build_join_schema( - left_input.schema(), - right_input.schema(), + &left_input, + &right_input, &JoinType::Inner, )?); @@ -393,11 +393,7 @@ fn find_inner_join( // no matching right plan had any join keys, cross join with the first right // plan let right = rights.remove(0); - let join_schema = Arc::new(build_join_schema( - left_input.schema(), - right.schema(), - &JoinType::Inner, - )?); + let join_schema = Arc::new(build_join_schema(&left_input, &right, &JoinType::Inner)?); Ok(LogicalPlan::Join(Join { left: Arc::new(left_input), @@ -1398,11 +1394,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; // Create an inner join with NullEquality::NullEqualsNull - let join_schema = Arc::new(build_join_schema( - t1.schema(), - t2.schema(), - &JoinType::Inner, - )?); + let join_schema = Arc::new(build_join_schema(&t1, &t2, 
&JoinType::Inner)?); let inner_join = LogicalPlan::Join(Join { left: Arc::new(t1), diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index acdbf71d05d5c..585c164d7eea7 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1021,10 +1021,8 @@ mod tests { impl UserDefinedCrossJoin { fn new(left_child: Arc, right_child: Arc) -> Self { - let left_schema = left_child.schema(); - let right_schema = right_child.schema(); let schema = Arc::new( - build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(), + build_join_schema(&left_child, &right_child, &JoinType::Inner).unwrap(), ); Self { exprs: vec![],