diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 39990ffe11f4b..635a7eb5a9611 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -514,7 +514,7 @@ impl LogicalPlanBuilder { /// Wrap a plan in a window pub fn window_plan( input: LogicalPlan, - window_exprs: Vec, + window_exprs: impl IntoIterator, ) -> Result { let mut plan = input; let mut groups = group_window_expr_by_sort_keys(window_exprs)?; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 9d7def76a11e5..552ce1502d466 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -575,7 +575,7 @@ pub fn compare_sort_expr( /// Group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( - window_expr: Vec, + window_expr: impl IntoIterator, ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 6752286a3c0f5..a23aad3d7237c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1059,7 +1059,7 @@ pub async fn from_project_rel( p: &ProjectRel, ) -> Result { if let Some(input) = p.input.as_ref() { - let mut input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let input = consumer.consume_rel(input).await?; let original_schema = Arc::clone(input.schema()); // Ensure that all expressions have a unique display name, so that @@ -1075,6 +1075,10 @@ pub async fn from_project_rel( // leaving only explicit expressions. let mut explicit_exprs: Vec = vec![]; + // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, + // we can do the window'ing only once, then the project will duplicate the result. + // Order here doesn't matter since LPB::window_plan sorts the expressions. + let mut window_exprs: HashSet = HashSet::new(); for expr in &p.expressions { let e = consumer .consume_expression(expr, input.clone().schema()) @@ -1084,18 +1088,24 @@ pub async fn from_project_rel( // Adding the same expression here and in the project below // works because the project's builder uses columnize_expr(..) // to transform it into a column reference - input = input.window(vec![e.clone()])? + window_exprs.insert(e.clone()); } explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } + let input = if !window_exprs.is_empty() { + LogicalPlanBuilder::window_plan(input, window_exprs)? + } else { + input + }; + let mut final_exprs: Vec = vec![]; for index in 0..original_schema.fields().len() { let e = Expr::Column(Column::from(original_schema.qualified_field(index))); final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } final_exprs.append(&mut explicit_exprs); - input.project(final_exprs)?.build() + project(input, final_exprs) } else { not_impl_err!("Projection without an input is not supported") } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 6f58995955489..579e3535f16d6 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -45,6 +45,10 @@ mod tests { "Projection: NOT DATA.D AS EXPR$0\ \n TableScan: DATA" ); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + Ok(()) } @@ -71,6 +75,63 @@ mod tests { \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: DATA" ); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + + #[tokio::test] + async fn double_window_function() -> Result<()> { + // Confirms a WindowExpr can be repeated in the same project. + // This wouldn't normally happen with DF-created plans since CSE would eliminate the duplicate. + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus --create "create table data (a int)" "select ROW_NUMBER() OVER (), ROW_NUMBER() OVER () AS aliased from data"; + let proto_plan = + read_json("tests/testdata/test_plans/double_window.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW__temp__0 AS ALIASED\ + \n WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: DATA" + ); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + + #[tokio::test] + async fn double_window_function_distinct_windows() -> Result<()> { + // Confirms a single project can have multiple window functions with separate windows in it. + // This wouldn't normally happen with DF-created plans since logical optimizer would + // separate them out. + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus --create "create table data (a int)" "select ROW_NUMBER() OVER (), ROW_NUMBER() OVER (PARTITION BY a) from data"; + let proto_plan = read_json( + "tests/testdata/test_plans/double_window_distinct_windows.substrait.json", + ); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$1\ + \n WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n TableScan: DATA" + ); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + Ok(()) } @@ -86,7 +147,7 @@ mod tests { assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); - // Need to trigger execution to ensure that Arrow has validated the plan + // Trigger execution to ensure plan validity DataFrame::new(ctx.state(), plan).show().await?; Ok(()) @@ -107,6 +168,9 @@ mod tests { \n TableScan: sales" ); + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + Ok(()) } } diff --git a/datafusion/substrait/tests/testdata/test_plans/double_window.substrait.json b/datafusion/substrait/tests/testdata/test_plans/double_window.substrait.json new file mode 100644 index 0000000000000..880f6fcae6cb9 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/double_window.substrait.json @@ -0,0 +1,126 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "row_number:" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1, + 2 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "A" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 0, + "partitions": [], + "sorts": [], + "upperBound": { + "currentRow": { + } + }, + "lowerBound": { + "unbounded": { + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + }, + { + "windowFunction": { + "functionReference": 0, + "partitions": [], + "sorts": [], + "upperBound": { + "currentRow": { + } + }, + "lowerBound": { + "unbounded": { + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "EXPR$0", + "ALIASED" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/double_window_distinct_windows.substrait.json b/datafusion/substrait/tests/testdata/test_plans/double_window_distinct_windows.substrait.json new file mode 100644 index 0000000000000..a8906e94c6661 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/double_window_distinct_windows.substrait.json @@ -0,0 +1,138 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "row_number:" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1, + 2 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "A" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 0, + "partitions": [], + "sorts": [], + "upperBound": { + "currentRow": { + } + }, + "lowerBound": { + "unbounded": { + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + }, + { + "windowFunction": { + "functionReference": 0, + "partitions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + ], + "sorts": [], + "upperBound": { + "currentRow": { + } + }, + "lowerBound": { + "unbounded": { + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "args": [], + "arguments": [], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "EXPR$0", + "EXPR$1" + ] + } + } + ], + "expectedTypeUrls": [] +}