From 7d3a96c5b6df6b63d78372b5b9d0a4f6014b3a70 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Mon, 6 Jan 2025 22:28:02 +0000 Subject: [PATCH 01/10] tests and optimizer in testing queries --- datafusion/sql/Cargo.toml | 1 + datafusion/sql/tests/cases/plan_to_sql.rs | 81 +++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 224c7cb191a38..504035a769255 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } +datafusion-optimizer = { workspace = true } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 24ec7f03deb08..e333fd51fb08b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -29,6 +29,8 @@ use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_nested::map::map_udf; use datafusion_functions_window::rank::rank_udwf; +use datafusion_optimizer::optimizer::Optimizer; +use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, DefaultDialect, @@ -57,6 +59,12 @@ use datafusion_sql::unparser::extension_unparser::{ use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; +fn optimize_plan(plan: LogicalPlan) -> Result { + let optimizer = Optimizer::new(); + let observer = |_plan: &LogicalPlan, _rule: &dyn OptimizerRule| {}; + optimizer.optimize(plan, &OptimizerContext::default(), observer) +} + #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ @@ -289,6 +297,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: &'static str, parser_dialect: Box, unparser_dialect: Box, + optimized: bool, } let tests: Vec = vec![ TestStatementWithDialect { @@ -299,6 +308,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", @@ -308,6 +318,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", @@ -315,6 +326,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select j1_id from (select 1 as j1_id);", @@ -322,6 +334,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select * from (select * from j1 limit 10);", @@ -329,6 +342,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", @@ -336,12 +350,14 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_id FROM j1 @@ -352,6 +368,27 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, + }, + TestStatementWithDialect { + sql: "SELECT 1 x, 'a' y UNION ALL + SELECT 1 x, 'b' y UNION ALL + SELECT 2 x, 'a' y UNION ALL + SELECT 2 x, 'c' y", + expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: true, + }, + TestStatementWithDialect { + sql: "SELECT 1 x, 'a' y UNION ALL + SELECT 1 x, 'b' y UNION ALL + SELECT 2 x, 'a' y UNION ALL + SELECT 2 x, 'c' y", + expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, // Test query with derived tables that put distinct,sort,limit on the wrong level TestStatementWithDialect { @@ -359,18 +396,21 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_string AS a from j1 order by j1_id", expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: " @@ -396,6 +436,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, // more tests around subquery/derived table roundtrip TestStatementWithDialect { @@ -413,6 +454,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, min(j2.j2_string) FROM j1 LEFT JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: " @@ -442,6 +484,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, // Test query that order by columns are not in select columns TestStatementWithDialect { @@ -467,18 +510,21 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, // Test query that has calculation in derived table with columns TestStatementWithDialect { @@ -486,6 +532,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, // Test query that has limit/distinct/order in derived table with columns TestStatementWithDialect { @@ -493,162 +540,189 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)", expected: r#"SELECT * FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)", expected: r#"SELECT * FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3])", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]), j1", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3])", expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]), j1", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * FROM UNNEST([4, 5, 6]) AS u (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3])", expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3]) as c1", expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3]), 1", expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), + optimized: false, }, ]; @@ -669,6 +743,13 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .sql_statement_to_plan(statement) .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); + let plan = if query.optimized { + optimize_plan(plan) + .unwrap_or_else(|e| panic!("Failed to optimize plan: {}\n{e}", query.sql)) + } else { + plan + }; + let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?; From 7ad585a4ee39fe573df909f8a3e46d8722f3e50b Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Mon, 6 Jan 2025 22:28:42 +0000 Subject: [PATCH 02/10] unparse optimized unions --- datafusion/sql/src/unparser/plan.rs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 6f30845eb8104..706a0bf4d1c43 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -706,13 +706,6 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Union(union) => { - if union.inputs.len() != 2 { - return not_impl_err!( - "UNION ALL expected 2 inputs, but found {}", - union.inputs.len() - ); - } - // Covers cases where the UNION is a subquery and the projection is at the top level if select.already_projected() { return self.derive_with_dialect_alias( @@ -723,18 +716,24 @@ impl Unparser<'_> { ); } + println!("union: {union:#?}"); + let input_exprs: Vec = union .inputs .iter() .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; - let union_expr = SetExpr::SetOperation { - op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, - left: Box::new(input_exprs[0].clone()), - right: Box::new(input_exprs[1].clone()), - }; + let union_expr = input_exprs + .into_iter() + .rev() + .reduce(|a, b| SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(b), + right: Box::new(a), + }) + .unwrap(); let Some(query) = query.as_mut() else { return internal_err!( From caa1a39b4c8347b1c7ccc79070ff6d130e860dbd Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Mon, 6 Jan 2025 22:46:11 +0000 Subject: [PATCH 03/10] format Cargo.toml --- datafusion/sql/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 504035a769255..4d54de8360ba3 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,11 +56,11 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } -datafusion-optimizer = { workspace = true } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } +datafusion-optimizer.workspace = true env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } From ff3394b5a06ff511ffef08121a91e2b2bb9762de Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Mon, 6 Jan 2025 22:50:53 +0000 Subject: [PATCH 04/10] format Cargo.toml --- datafusion/sql/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 4d54de8360ba3..024923a610c42 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -60,7 +60,7 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } -datafusion-optimizer.workspace = true +datafusion-optimizer = { workspace = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } From bbc43d3be2dd32e8d02169ad17b72b9abbb5ec80 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 7 Jan 2025 17:25:16 +0000 Subject: [PATCH 05/10] revert test --- datafusion/sql/Cargo.toml | 1 - datafusion/sql/tests/cases/plan_to_sql.rs | 63 ----------------------- 2 files changed, 64 deletions(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 024923a610c42..224c7cb191a38 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -60,7 +60,6 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } -datafusion-optimizer = { workspace = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index e333fd51fb08b..c928659a0e3e8 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -29,8 +29,6 @@ use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_nested::map::map_udf; use datafusion_functions_window::rank::rank_udwf; -use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, DefaultDialect, @@ -59,12 +57,6 @@ use datafusion_sql::unparser::extension_unparser::{ use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; -fn optimize_plan(plan: LogicalPlan) -> Result { - let optimizer = Optimizer::new(); - let observer = |_plan: &LogicalPlan, _rule: &dyn OptimizerRule| {}; - optimizer.optimize(plan, &OptimizerContext::default(), observer) -} - #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ @@ -297,7 +289,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: &'static str, parser_dialect: Box, unparser_dialect: Box, - optimized: bool, } let tests: Vec = vec![ TestStatementWithDialect { @@ -308,7 +299,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", @@ -318,7 +308,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", @@ -326,7 +315,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select j1_id from (select 1 as j1_id);", @@ -334,7 +322,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select * from (select * from j1 limit 10);", @@ -342,7 +329,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", @@ -350,14 +336,12 @@ fn roundtrip_statement_with_dialect() -> Result<()> { "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_id FROM j1 @@ -368,7 +352,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT 1 x, 'a' y UNION ALL @@ -378,7 +361,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: true, }, TestStatementWithDialect { sql: "SELECT 1 x, 'a' y UNION ALL @@ -388,7 +370,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, // Test query with derived tables that put distinct,sort,limit on the wrong level TestStatementWithDialect { @@ -396,21 +377,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_string AS a from j1 order by j1_id", expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: " @@ -436,7 +414,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, // more tests around subquery/derived table roundtrip TestStatementWithDialect { @@ -454,7 +431,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, min(j2.j2_string) FROM j1 LEFT JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: " @@ -484,7 +460,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, // Test query that order by columns are not in select columns TestStatementWithDialect { @@ -510,21 +485,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, // Test query that has calculation in derived table with columns TestStatementWithDialect { @@ -532,7 +504,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, // Test query that has limit/distinct/order in derived table with columns TestStatementWithDialect { @@ -540,189 +511,162 @@ fn roundtrip_statement_with_dialect() -> Result<()> { expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)", expected: r#"SELECT * FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)", expected: r#"SELECT * FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3])", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]), j1", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3])", expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]), j1", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * FROM UNNEST([4, 5, 6]) AS u (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3])", expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3]) as c1", expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT UNNEST([1,2,3]), 1", expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col) AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))") AS t1 (c1)"#, parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), - optimized: false, }, ]; @@ -743,13 +687,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .sql_statement_to_plan(statement) .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); - let plan = if query.optimized { - optimize_plan(plan) - .unwrap_or_else(|e| panic!("Failed to optimize plan: {}\n{e}", query.sql)) - } else { - plan - }; - let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?; From 0fdc798deab24ded982e6eb044e2740092a0293e Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 7 Jan 2025 18:27:08 +0000 Subject: [PATCH 06/10] rewrite test to avoid cyclic dep --- datafusion/sql/tests/cases/plan_to_sql.rs | 41 ++++++++++++++++++++--- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index c928659a0e3e8..b81a43e47edf7 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::*; +use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference}; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder, - UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, + LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; @@ -42,7 +42,7 @@ use std::{fmt, vec}; use crate::common::{MockContextProvider, MockSessionState}; use datafusion_expr::builder::{ - table_scan_with_filter_and_fetch, table_scan_with_filters, + project, table_scan_with_filter_and_fetch, table_scan_with_filters, }; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_nested::extract::array_element_udf; @@ -1633,3 +1633,36 @@ fn test_unparse_extension_to_sql() -> Result<()> { } Ok(()) } + +#[test] +fn test_unparse_optimized_multi_union() -> Result<()> { + let unparser = Unparser::default(); + + let schema = Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + + let dfschema = Arc::new(DFSchema::try_from(schema)?); + + let empty = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: dfschema.clone(), + }); + + let plan = LogicalPlan::Union(Union { + inputs: vec![ + project(empty.clone(), vec![lit(1).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(1).alias("x"), lit("b").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("a").alias("y")])?.into(), + project(empty.clone(), vec![lit(2).alias("x"), lit("c").alias("y")])?.into(), + ], + schema: dfschema.clone(), + }); + + let sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"; + + assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql); + + Ok(()) +} From 9392454ef35cdc0999f80c90678bc10c598ba38f Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 7 Jan 2025 18:56:03 +0000 Subject: [PATCH 07/10] remove old test --- datafusion/sql/tests/cases/plan_to_sql.rs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index b81a43e47edf7..a6ce9e88560c7 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -353,24 +353,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, - TestStatementWithDialect { - sql: "SELECT 1 x, 'a' y UNION ALL - SELECT 1 x, 'b' y UNION ALL - SELECT 2 x, 'a' y UNION ALL - SELECT 2 x, 'c' y", - expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT 1 x, 'a' y UNION ALL - SELECT 1 x, 'b' y UNION ALL - SELECT 2 x, 'a' y UNION ALL - SELECT 2 x, 'c' y", - expected: r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, // Test query with derived tables that put distinct,sort,limit on the wrong level TestStatementWithDialect { sql: "SELECT j1_string from j1 order by j1_id", From e1267a326e31a73f7f0db9cc1e700bc5e7b910f0 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 7 Jan 2025 19:49:34 +0000 Subject: [PATCH 08/10] cleanup --- datafusion/sql/src/unparser/plan.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 706a0bf4d1c43..bacf3dc37415f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -716,8 +716,6 @@ impl Unparser<'_> { ); } - println!("union: {union:#?}"); - let input_exprs: Vec = union .inputs .iter() From 786e85ad2064e83dda1cc38785a7db8119bcab1b Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 8 Jan 2025 10:40:48 +0000 Subject: [PATCH 09/10] comments and error handling --- datafusion/sql/src/unparser/plan.rs | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index bacf3dc37415f..32b66cddaab67 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -722,16 +722,21 @@ impl Unparser<'_> { .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; - let union_expr = input_exprs - .into_iter() - .rev() - .reduce(|a, b| SetExpr::SetOperation { - op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, - left: Box::new(b), - right: Box::new(a), - }) - .unwrap(); + // Build the union expression tree bottom-up by reversing the order + // note that we are also swapping left and right inputs because of the rev + let Some(union_expr) = + input_exprs + .into_iter() + .rev() + .reduce(|a, b| SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(b), + right: Box::new(a), + }) + else { + return internal_err!("UNION operator requires at least 2 inputs"); + }; let Some(query) = query.as_mut() else { return internal_err!( From 77d09ce187072085e7352dce623b382536c0c6ee Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 8 Jan 2025 15:22:38 +0000 Subject: [PATCH 10/10] handle union with lt 2 inputs --- datafusion/sql/src/unparser/plan.rs | 27 ++++++++++++----------- datafusion/sql/tests/cases/plan_to_sql.rs | 15 +++++++++++++ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 32b66cddaab67..2bad683dc1bcc 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -722,21 +722,22 @@ impl Unparser<'_> { .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; + if input_exprs.len() < 2 { + return internal_err!("UNION operator requires at least 2 inputs"); + } + // Build the union expression tree bottom-up by reversing the order // note that we are also swapping left and right inputs because of the rev - let Some(union_expr) = - input_exprs - .into_iter() - .rev() - .reduce(|a, b| SetExpr::SetOperation { - op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, - left: Box::new(b), - right: Box::new(a), - }) - else { - return internal_err!("UNION operator requires at least 2 inputs"); - }; + let union_expr = input_exprs + .into_iter() + .rev() + .reduce(|a, b| SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(b), + right: Box::new(a), + }) + .unwrap(); let Some(query) = query.as_mut() else { return internal_err!( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index a6ce9e88560c7..94b4df59ef00b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1646,5 +1646,20 @@ fn test_unparse_optimized_multi_union() -> Result<()> { assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql); + let plan = LogicalPlan::Union(Union { + inputs: vec![project( + empty.clone(), + vec![lit(1).alias("x"), lit("a").alias("y")], + )? + .into()], + schema: dfschema.clone(), + }); + + if let Some(err) = plan_to_sql(&plan).err() { + assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs"); + } else { + panic!("Expected error") + } + Ok(()) }