diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs new file mode 100644 index 0000000000000..5b1f81e820a29 --- /dev/null +++ b/datafusion/sql/src/cte.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +use arrow::datatypes::Schema; +use datafusion_common::{ + not_impl_err, plan_err, + tree_node::{TreeNode, TreeNodeRecursion}, + Result, +}; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; +use sqlparser::ast::{Query, SetExpr, SetOperator, With}; + +impl<'a, S: ContextProvider> SqlToRel<'a, S> { + pub(super) fn plan_with_clause( + &self, + with: With, + planner_context: &mut PlannerContext, + ) -> Result<()> { + let is_recursive = with.recursive; + // Process CTEs from top to bottom + for cte in with.cte_tables { + // A `WITH` block can't use the same name more than once + let cte_name = self.normalizer.normalize(cte.alias.name.clone()); + if planner_context.contains_cte(&cte_name) { + return plan_err!( + "WITH query name {cte_name:?} specified more than once" + ); + } + + // Create a logical plan for the CTE + let cte_plan = if is_recursive { + self.recursive_cte(cte_name.clone(), *cte.query, planner_context)? + } else { + self.non_recursive_cte(*cte.query, planner_context)? + }; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let final_plan = self.apply_table_alias(cte_plan, cte.alias)?; + // Export the CTE to the outer query + planner_context.insert_cte(cte_name, final_plan); + } + Ok(()) + } + + fn non_recursive_cte( + &self, + cte_query: Query, + planner_context: &mut PlannerContext, + ) -> Result { + // CTE expr don't need extend outer_query_schema, + // so we clone a new planner_context here. + let mut cte_planner_context = planner_context.clone(); + self.query_to_plan(cte_query, &mut cte_planner_context) + } + + fn recursive_cte( + &self, + cte_name: String, + mut cte_query: Query, + planner_context: &mut PlannerContext, + ) -> Result { + if !self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return not_impl_err!("Recursive CTEs are not enabled"); + } + + let (left_expr, right_expr, set_quantifier) = match *cte_query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => (left, right, set_quantifier), + other => { + // If the query is not a UNION, then it is not a recursive CTE + cte_query.body = Box::new(other); + return self.non_recursive_cte(cte_query, planner_context); + } + }; + + // Each recursive CTE consists from two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where the + // referencing to the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. + + // ---------- Step 1: Compile the static term ------------------ + let static_plan = + self.set_expr_to_plan(*left_expr, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. + + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a table source for the temporary relation + let work_table_source = self.context_provider.create_cte_work_table( + &cte_name, + Arc::new(Schema::from(static_plan.schema().as_ref())), + )?; + + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + work_table_source.clone(), + None, + )? + .build()?; + + let name = cte_name.clone(); + + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. + planner_context.insert_cte(cte_name.clone(), work_table_plan); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = + self.set_expr_to_plan(*right_expr, &mut planner_context.clone())?; + + // Check if the recursive term references the CTE itself, + // if not, it is a non-recursive CTE + if !has_work_table_reference(&recursive_plan, &work_table_source) { + // Remove the work table plan from the context + planner_context.remove_cte(&cte_name); + // Compile it as a non-recursive CTE + return self.set_operation_to_plan( + SetOperator::Union, + static_plan, + recursive_plan, + set_quantifier, + ); + } + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let distinct = !Self::is_union_all(set_quantifier)?; + LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build() + } +} + +fn has_work_table_reference( + plan: &LogicalPlan, + work_table_source: &Arc, +) -> bool { + let mut has_reference = false; + plan.apply(&mut |node| { + if let LogicalPlan::TableScan(scan) = node { + if Arc::ptr_eq(&scan.source, work_table_source) { + has_reference = true; + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }) + // Closure always return Ok + .unwrap(); + has_reference +} diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 12d6a46696346..1040cc61c702b 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -28,6 +28,7 @@ //! [`SqlToRel`]: planner::SqlToRel //! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan +mod cte; mod expr; pub mod parser; pub mod planner; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index f94c6ec4e8c93..d2182962b98ee 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -213,6 +213,11 @@ impl PlannerContext { pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + + /// Remove the plan of CTE / Subquery for the specified name + pub(super) fn remove_cte(&mut self, cte_name: &str) { + self.ctes.remove(cte_name); + } } /// SQL query planner diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index eda8398c432b2..ba876d052f5e2 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,21 +19,15 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::datatypes::Schema; -use datafusion_common::{ - not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{plan_err, Constraints, Result, ScalarValue}; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, - SetQuantifier, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, }; -use sqlparser::parser::ParserError::ParserError; - impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an SQL query pub(crate) fn query_to_plan( @@ -54,139 +48,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let set_expr = query.body; if let Some(with) = query.with { - // Process CTEs from top to bottom - let is_recursive = with.recursive; - - for cte in with.cte_tables { - // A `WITH` block can't use the same name more than once - let cte_name = self.normalizer.normalize(cte.alias.name.clone()); - if planner_context.contains_cte(&cte_name) { - return sql_err!(ParserError(format!( - "WITH query name {cte_name:?} specified more than once" - ))); - } - - if is_recursive { - if !self - .context_provider - .options() - .execution - .enable_recursive_ctes - { - return not_impl_err!("Recursive CTEs are not enabled"); - } - - match *cte.query.body { - SetExpr::SetOperation { - op: SetOperator::Union, - left, - right, - set_quantifier, - } => { - let distinct = set_quantifier != SetQuantifier::All; - - // Each recursive CTE consists from two parts in the logical plan: - // 1. A static term (the left hand side on the SQL, where the - // referencing to the same CTE is not allowed) - // - // 2. A recursive term (the right hand side, and the recursive - // part) - - // Since static term does not have any specific properties, it can - // be compiled as if it was a regular expression. This will - // allow us to infer the schema to be used in the recursive term. - - // ---------- Step 1: Compile the static term ------------------ - let static_plan = self - .set_expr_to_plan(*left, &mut planner_context.clone())?; - - // Since the recursive CTEs include a component that references a - // table with its name, like the example below: - // - // WITH RECURSIVE values(n) AS ( - // SELECT 1 as n -- static term - // UNION ALL - // SELECT n + 1 - // FROM values -- self reference - // WHERE n < 100 - // ) - // - // We need a temporary 'relation' to be referenced and used. PostgreSQL - // calls this a 'working table', but it is entirely an implementation - // detail and a 'real' table with that name might not even exist (as - // in the case of DataFusion). - // - // Since we can't simply register a table during planning stage (it is - // an execution problem), we'll use a relation object that preserves the - // schema of the input perfectly and also knows which recursive CTE it is - // bound to. - - // ---------- Step 2: Create a temporary relation ------------------ - // Step 2.1: Create a table source for the temporary relation - let work_table_source = - self.context_provider.create_cte_work_table( - &cte_name, - Arc::new(Schema::from(static_plan.schema().as_ref())), - )?; - - // Step 2.2: Create a temporary relation logical plan that will be used - // as the input to the recursive term - let work_table_plan = LogicalPlanBuilder::scan( - cte_name.to_string(), - work_table_source, - None, - )? - .build()?; - - let name = cte_name.clone(); - - // Step 2.3: Register the temporary relation in the planning context - // For all the self references in the variadic term, we'll replace it - // with the temporary relation we created above by temporarily registering - // it as a CTE. This temporary relation in the planning context will be - // replaced by the actual CTE plan once we're done with the planning. - planner_context.insert_cte(cte_name.clone(), work_table_plan); - - // ---------- Step 3: Compile the recursive term ------------------ - // this uses the named_relation we inserted above to resolve the - // relation. This ensures that the recursive term uses the named relation logical plan - // and thus the 'continuance' physical plan as its input and source - let recursive_plan = self - .set_expr_to_plan(*right, &mut planner_context.clone())?; - - // ---------- Step 4: Create the final plan ------------------ - // Step 4.1: Compile the final plan - let logical_plan = LogicalPlanBuilder::from(static_plan) - .to_recursive_query(name, recursive_plan, distinct)? - .build()?; - - let final_plan = - self.apply_table_alias(logical_plan, cte.alias)?; - - // Step 4.2: Remove the temporary relation from the planning context and replace it - // with the final plan. - planner_context.insert_cte(cte_name.clone(), final_plan); - } - _ => { - return Err(DataFusionError::SQL( - ParserError(format!("Unsupported CTE: {cte}")), - None, - )); - } - }; - } else { - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; - - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; - - planner_context.insert_cte(cte_name, logical_plan); - } - } + self.plan_with_clause(with, planner_context)?; } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; let plan = self.order_by(plan, query.order_by, planner_context)?; diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 2cbb68368f722..cbe41c33c729c 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -35,45 +35,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { right, set_quantifier, } => { - let all = match set_quantifier { - SetQuantifier::All => true, - SetQuantifier::Distinct | SetQuantifier::None => false, - SetQuantifier::ByName => { - return not_impl_err!("UNION BY NAME not implemented"); - } - SetQuantifier::AllByName => { - return not_impl_err!("UNION ALL BY NAME not implemented") - } - SetQuantifier::DistinctByName => { - return not_impl_err!("UNION DISTINCT BY NAME not implemented") - } - }; - let left_plan = self.set_expr_to_plan(*left, planner_context)?; let right_plan = self.set_expr_to_plan(*right, planner_context)?; - match (op, all) { - (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) - .union(right_plan)? - .build(), - (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) - .union_distinct(right_plan)? - .build(), - (SetOperator::Intersect, true) => { - LogicalPlanBuilder::intersect(left_plan, right_plan, true) - } - (SetOperator::Intersect, false) => { - LogicalPlanBuilder::intersect(left_plan, right_plan, false) - } - (SetOperator::Except, true) => { - LogicalPlanBuilder::except(left_plan, right_plan, true) - } - (SetOperator::Except, false) => { - LogicalPlanBuilder::except(left_plan, right_plan, false) - } - } + self.set_operation_to_plan(op, left_plan, right_plan, set_quantifier) } SetExpr::Query(q) => self.query_to_plan(*q, planner_context), _ => not_impl_err!("Query {set_expr} not implemented yet"), } } + + pub(super) fn is_union_all(set_quantifier: SetQuantifier) -> Result { + match set_quantifier { + SetQuantifier::All => Ok(true), + SetQuantifier::Distinct | SetQuantifier::None => Ok(false), + SetQuantifier::ByName => { + not_impl_err!("UNION BY NAME not implemented") + } + SetQuantifier::AllByName => { + not_impl_err!("UNION ALL BY NAME not implemented") + } + SetQuantifier::DistinctByName => { + not_impl_err!("UNION DISTINCT BY NAME not implemented") + } + } + } + + pub(super) fn set_operation_to_plan( + &self, + op: SetOperator, + left_plan: LogicalPlan, + right_plan: LogicalPlan, + set_quantifier: SetQuantifier, + ) -> Result { + let all = Self::is_union_all(set_quantifier)?; + match (op, all) { + (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) + .union(right_plan)? + .build(), + (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) + .union_distinct(right_plan)? + .build(), + (SetOperator::Intersect, true) => { + LogicalPlanBuilder::intersect(left_plan, right_plan, true) + } + (SetOperator::Intersect, false) => { + LogicalPlanBuilder::intersect(left_plan, right_plan, false) + } + (SetOperator::Except, true) => { + LogicalPlanBuilder::except(left_plan, right_plan, true) + } + (SetOperator::Except, false) => { + LogicalPlanBuilder::except(left_plan, right_plan, false) + } + } + } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 448a9c54202e3..cf315d89fe1e2 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2992,16 +2992,6 @@ fn join_with_aliases() { quick_test(sql, expected); } -#[test] -fn cte_use_same_name_multiple_times() { - let sql = - "with a as (select * from person), a as (select * from orders) select * from a;"; - let expected = - "SQL error: ParserError(\"WITH query name \\\"a\\\" specified more than once\")"; - let result = logical_plan(sql).err().unwrap(); - assert_eq!(result.strip_backtrace(), expected); -} - #[test] fn negative_interval_plus_interval_in_projection() { let sql = "select -interval '2 days' + interval '5 days';"; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e33dfabaf2caa..eec7eb0e3399a 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -39,6 +39,37 @@ physical_plan ProjectionExec: expr=[1 as a, 2 as b, 3 as c] --PlaceholderRowExec +# cte_use_same_name_multiple_times +statement error DataFusion error: Error during planning: WITH query name "a" specified more than once +WITH a AS (SELECT 1), a AS (SELECT 2) SELECT * FROM a; + +# Test disabling recursive CTE +statement ok +set datafusion.execution.enable_recursive_ctes = false; + +query error DataFusion error: This feature is not implemented: Recursive CTEs are not enabled +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 3 +) SELECT * FROM nodes + +statement ok +set datafusion.execution.enable_recursive_ctes = true; + + +# DISTINCT UNION is not supported +query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION + SELECT id + 1 as id + FROM nodes + WHERE id < 3 +) SELECT * FROM nodes + # trivial recursive CTE works query I rowsort @@ -744,3 +775,60 @@ WITH RECURSIVE my_cte AS ( UNION ALL SELECT 'abc' FROM my_cte WHERE CAST(a AS text) !='abc' ) SELECT * FROM my_cte; + +# Define a non-recursive CTE in the recursive WITH clause. +# Test issue: https://github.com/apache/arrow-datafusion/issues/9804 +query I +WITH RECURSIVE cte AS ( + SELECT a FROM (VALUES(1)) AS t(a) WHERE a > 2 + UNION ALL + SELECT 2 +) SELECT * FROM cte; +---- +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# UNION ALL +query I rowsort +WITH RECURSIVE cte AS ( + SELECT 1 + UNION ALL + SELECT 2 +) SELECT * FROM cte; +---- +1 +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# DISTINCT UNION +query I +WITH RECURSIVE cte AS ( + SELECT 2 + UNION + SELECT 2 +) SELECT * FROM cte; +---- +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# UNION is not present. +query I +WITH RECURSIVE cte AS ( + SELECT 1 +) SELECT * FROM cte; +---- +1 + +# Define a recursive CTE and a non-recursive CTE at the same time. +query II rowsort +WITH RECURSIVE +non_recursive_cte AS ( + SELECT 1 +), +recursive_cte AS ( + SELECT 1 AS a UNION ALL SELECT a+2 FROM recursive_cte WHERE a < 3 +) +SELECT * FROM non_recursive_cte, recursive_cte; +---- +1 1 +1 3