Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1154,8 +1154,7 @@ impl LogicalPlanBuilder {
.zip(right_keys)
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
.collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
let join_schema = build_join_schema(&self.plan, &right, &join_type)?;

// Inner type without join condition is cross join
if join_type != JoinType::Inner && on.is_empty() && filter.is_none() {
Expand Down Expand Up @@ -1652,16 +1651,22 @@ pub fn unique_field_aliases(fields: &Fields) -> Vec<Option<String>> {
.collect()
}

fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {
let mut table_references = schema
.iter()
.filter_map(|(qualifier, _)| qualifier)
.collect::<Vec<_>>();
table_references.dedup();
let table_reference = if table_references.len() == 1 {
table_references.pop().cloned()
fn mark_field(child_plan: &LogicalPlan) -> (Option<TableReference>, Arc<Field>) {
let table_reference = if let LogicalPlan::SubqueryAlias(plan) = child_plan {
Some(plan.alias.clone())
} else {
None
let mut table_references = child_plan
.schema()
.iter()
.filter_map(|(qualifier, _)| qualifier)
.collect::<Vec<_>>();
table_references.dedup();

if table_references.len() == 1 {
table_references.pop().cloned()
} else {
None
}
};

(
Expand All @@ -1673,8 +1678,8 @@ fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {
/// Creates a schema for a join operation.
/// The fields from the left side are first
pub fn build_join_schema(
left: &DFSchema,
right: &DFSchema,
left_plan: &LogicalPlan,
right_plan: &LogicalPlan,
join_type: &JoinType,
) -> Result<DFSchema> {
fn nullify_fields<'a>(
Expand All @@ -1689,6 +1694,8 @@ pub fn build_join_schema(
.collect()
}

let left = left_plan.schema();
let right = right_plan.schema();
let right_fields = right.iter();
let left_fields = left.iter();

Expand Down Expand Up @@ -1738,7 +1745,7 @@ pub fn build_join_schema(
}
JoinType::LeftMark => left_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(right)))
.chain(once(mark_field(right_plan)))
.collect(),
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
Expand All @@ -1748,7 +1755,7 @@ pub fn build_join_schema(
}
JoinType::RightMark => right_fields
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.chain(once(mark_field(left)))
.chain(once(mark_field(left_plan)))
.collect(),
};
let func_dependencies = left.functional_dependencies().join(
Expand Down Expand Up @@ -2910,15 +2917,35 @@ mod tests {
vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))],
HashMap::from([("key".to_string(), "right".to_string())]),
)?;

let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Left)?;
let left_schema = Arc::new(left_schema);
let right_schema = Arc::new(right_schema);

let join_schema = build_join_schema(
&LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::clone(&left_schema),
}),
&LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::clone(&right_schema),
}),
&JoinType::Left,
)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "left".to_string())])
);
let join_schema =
build_join_schema(&left_schema, &right_schema, &JoinType::Right)?;
let join_schema = build_join_schema(
&LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::clone(&left_schema),
}),
&LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::clone(&right_schema),
}),
&JoinType::Right,
)?;
assert_eq!(
join_schema.metadata(),
&HashMap::from([("key".to_string(), "right".to_string())])
Expand Down
11 changes: 5 additions & 6 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,7 @@ impl LogicalPlan {
null_equality,
null_aware,
}) => {
let schema =
build_join_schema(left.schema(), right.schema(), &join_type)?;
let schema = build_join_schema(&left, &right, &join_type)?;

let new_on: Vec<_> = on
.into_iter()
Expand Down Expand Up @@ -944,7 +943,7 @@ impl LogicalPlan {
..
}) => {
let (left, right) = self.only_two_inputs(inputs)?;
let schema = build_join_schema(left.schema(), right.schema(), join_type)?;
let schema = build_join_schema(&left, &right, join_type)?;

let equi_expr_count = on.len() * 2;
assert!(expr.len() >= equi_expr_count);
Expand Down Expand Up @@ -4269,7 +4268,7 @@ impl Join {
null_equality: NullEquality,
null_aware: bool,
) -> Result<Self> {
let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?;
let join_schema = build_join_schema(&left, &right, &join_type)?;

Ok(Join {
left,
Expand Down Expand Up @@ -4321,8 +4320,8 @@ impl Join {
.collect();

let join_schema = build_join_schema(
left_sch.schema(),
right_sch.schema(),
&left_sch.build()?,
&right_sch.build()?,
&original_join.join_type,
)?;

Expand Down
164 changes: 112 additions & 52 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
use datafusion_expr::{
BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists,
in_subquery, lit, not, not_exists, not_in_subquery,
BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, Projection,
exists, in_subquery, lit, not, not_exists, not_in_subquery,
};

use log::debug;
Expand Down Expand Up @@ -69,70 +69,104 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
})?
.data;

let LogicalPlan::Filter(filter) = plan else {
return Ok(Transformed::no(plan));
};

if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
if let LogicalPlan::Filter(filter) = plan {
rewrite_filter(filter, config)
} else if let LogicalPlan::Projection(project) = plan {
rewrite_project(project, config)
} else {
Ok(Transformed::no(plan))
}
}

let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);
fn name(&self) -> &str {
"decorrelate_predicate_subquery"
}

assert_or_internal_err!(
!with_subqueries.is_empty(),
"can not find expected subqueries in DecorrelatePredicateSubquery"
);
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
}

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
let original_schema = cur_input.schema().columns();
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(&subquery, &cur_input, config.alias_generator())?
{
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);
fn rewrite_filter(
filter: Filter,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}

let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);

assert_or_internal_err!(
!with_subqueries.is_empty(),
"can not find expected subqueries in DecorrelatePredicateSubquery"
);

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
let original_schema = cur_input.schema().columns();
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(&subquery, &cur_input, config.alias_generator())? {
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);
}
}
}

let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}

if cur_input.schema().fields().len() != original_schema.len() {
cur_input = LogicalPlanBuilder::from(cur_input)
.project(original_schema.into_iter().map(Expr::from))?
.build()?;
}
let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}

Ok(Transformed::yes(cur_input))
if cur_input.schema().fields().len() != original_schema.len() {
cur_input = LogicalPlanBuilder::from(cur_input)
.project(original_schema.into_iter().map(Expr::from))?
.build()?;
}

fn name(&self) -> &str {
"decorrelate_predicate_subquery"
Ok(Transformed::yes(cur_input))
}

fn rewrite_project(
project: Projection,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
if !project.expr.iter().any(has_subquery) {
return Ok(Transformed::no(LogicalPlan::Projection(project)));
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
let mut cur_input = Arc::unwrap_or_clone(project.input);
let mut new_projections = Vec::with_capacity(project.expr.len());
for (expression, field) in project.expr.into_iter().zip(project.schema.fields()) {
if !has_subquery(&expression) {
new_projections.push(expression);
continue;
}

let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expression, config)?;
cur_input = plan;
new_projections.push(expr_without_subqueries.alias(field.name()));
}

let new_project = Projection::try_new(new_projections, Arc::new(cur_input))?;
Ok(Transformed::yes(LogicalPlan::Projection(new_project)))
}

fn rewrite_inner_subqueries(
Expand Down Expand Up @@ -546,6 +580,7 @@ mod tests {
use crate::assert_optimized_plan_eq_display_indent_snapshot;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::builder::table_source;
use datafusion_expr::select_expr::SelectExpr;
use datafusion_expr::{and, binary_expr, col, out_ref_col, table_scan};

macro_rules! assert_optimized_plan_equal {
Expand Down Expand Up @@ -2123,4 +2158,29 @@ mod tests {
"
)
}

#[test]
fn exist_projection() -> Result<()> {
let fields = vec![Field::new("A", DataType::UInt32, false)];

let schema = Schema::new(fields);
let subquery_scan = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?;

let plan = LogicalPlanBuilder::empty(true)
.project(vec![SelectExpr::Expression(exists(Arc::new(
subquery_scan,
)))])?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Projection: __correlated_sq_1.mark AS EXISTS [EXISTS:Boolean]
LeftMark Join: Filter: Boolean(true) [mark:Boolean]
EmptyRelation: rows=1 []
SubqueryAlias: __correlated_sq_1 [A:UInt32]
TableScan: TEST_A [A:UInt32]
"
)
}
}
Loading
Loading