Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions datafusion/catalog/src/cte_worktable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ use crate::{ScanArgs, ScanResult, Session, TableProvider};
pub struct CteWorkTable {
/// The name of the CTE work table
name: String,
/// This schema must be shared across both the static and recursive terms of a recursive query
/// Schema exposed by recursive self-references while planning the recursive term.
///
/// This is a conservative work-table schema, not the final recursive query output
/// schema. For example, the SQL planner may mark fields nullable here so recursive
/// references do not inherit unsound anchor-term nullability assumptions.
table_schema: SchemaRef,
}

impl CteWorkTable {
/// construct a new CteWorkTable with the given name and schema
/// This schema must match the schema of the recursive term of the query
/// Since the scan method will contain an physical plan that assumes this schema
/// Construct a new CteWorkTable with the given name and self-reference schema.
pub fn new(name: &str, table_schema: SchemaRef) -> Self {
Self {
name: name.to_owned(),
Expand All @@ -56,7 +58,7 @@ impl CteWorkTable {
&self.name
}

/// The schema of the recursive term of the query
/// The schema exposed by scans of the recursive self-reference.
pub fn schema(&self) -> SchemaRef {
Arc::clone(&self.table_schema)
}
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1785,11 +1785,15 @@ impl DefaultPhysicalPlanner {
}
}
LogicalPlan::RecursiveQuery(RecursiveQuery {
name, is_distinct, ..
name,
is_distinct,
schema,
..
}) => {
let [static_term, recursive_term] = children.two()?;
Arc::new(RecursiveQueryExec::try_new(
name.clone(),
Arc::clone(schema.inner()),
static_term,
recursive_term,
*is_distinct,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> {
SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false]
RecursiveQueryExec: name=number_series, is_distinct=false
CoalescePartitionsExec
ProjectionExec: expr=[id@0 as id, 1 as level]
ProjectionExec: expr=[CAST(id@0 AS Int64) as id, CAST(1 AS Int64) as level]
FilterExec: id@0 = 1
RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1
DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)]
Expand Down
9 changes: 5 additions & 4 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,13 @@ impl LogicalPlanBuilder {
// Ensure that the recursive term has the same field types as the static term
let coerced_recursive_term =
coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?;
Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery {
let recursive_query = RecursiveQuery::try_new(
name,
static_term: self.plan,
recursive_term: Arc::new(coerced_recursive_term),
self.plan,
Arc::new(coerced_recursive_term),
is_distinct,
})))
)?;
Ok(Self::from(LogicalPlan::RecursiveQuery(recursive_query)))
}

/// Create a values list based relation, and the schema is inferred from data, consuming
Expand Down
181 changes: 169 additions & 12 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,7 @@ impl LogicalPlan {
LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema,
LogicalPlan::Ddl(ddl) => ddl.schema(),
LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
// we take the schema of the static term as the schema of the entire recursive query
static_term.schema()
}
LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema,
}
}

Expand Down Expand Up @@ -741,7 +738,14 @@ impl LogicalPlan {
};
Ok(LogicalPlan::Distinct(distinct))
}
LogicalPlan::RecursiveQuery(_) => Ok(self),
LogicalPlan::RecursiveQuery(RecursiveQuery {
name,
static_term,
recursive_term,
is_distinct,
schema: _,
}) => RecursiveQuery::try_new(name, static_term, recursive_term, is_distinct)
.map(LogicalPlan::RecursiveQuery),
LogicalPlan::Analyze(_) => Ok(self),
LogicalPlan::Explain(_) => Ok(self),
LogicalPlan::TableScan(_) => Ok(self),
Expand Down Expand Up @@ -1081,12 +1085,13 @@ impl LogicalPlan {
}) => {
self.assert_no_expressions(expr)?;
let (static_term, recursive_term) = self.only_two_inputs(inputs)?;
Ok(LogicalPlan::RecursiveQuery(RecursiveQuery {
name: name.clone(),
static_term: Arc::new(static_term),
recursive_term: Arc::new(recursive_term),
is_distinct: *is_distinct,
}))
RecursiveQuery::try_new(
name.clone(),
Arc::new(static_term),
Arc::new(recursive_term),
*is_distinct,
)
.map(LogicalPlan::RecursiveQuery)
}
LogicalPlan::Analyze(a) => {
self.assert_no_expressions(expr)?;
Expand Down Expand Up @@ -2258,7 +2263,7 @@ impl PartialOrd for EmptyRelation {
/// intermediate table, then empty the intermediate table.
///
/// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RecursiveQuery {
/// Name of the query
pub name: String,
Expand All @@ -2270,6 +2275,90 @@ pub struct RecursiveQuery {
/// Should the output of the recursive term be deduplicated (`UNION`) or
/// not (`UNION ALL`).
pub is_distinct: bool,
/// Schema exposed to parent plans after reconciling the static and recursive terms.
pub schema: DFSchemaRef,
}

impl PartialOrd for RecursiveQuery {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.name.partial_cmp(&other.name) {
Some(Ordering::Equal) => {
match self.static_term.partial_cmp(&other.static_term) {
Some(Ordering::Equal) => {
match self.recursive_term.partial_cmp(&other.recursive_term) {
Some(Ordering::Equal) => {
self.is_distinct.partial_cmp(&other.is_distinct)
}
cmp => cmp,
}
}
cmp => cmp,
}
}
cmp => cmp,
}
// If the query definition compares equal but the derived schema differs,
// return `None` instead of contradicting `PartialEq` with `Some(Equal)`.
// TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
.filter(|cmp| *cmp != Ordering::Equal || self == other)
}
}

impl RecursiveQuery {
pub fn try_new(
name: String,
static_term: Arc<LogicalPlan>,
recursive_term: Arc<LogicalPlan>,
is_distinct: bool,
) -> Result<Self> {
let schema =
recursive_query_output_schema(static_term.schema(), recursive_term.schema())?;
Ok(Self {
name,
static_term,
recursive_term,
is_distinct,
schema,
})
}
}

/// Compute a recursive query's output schema by considering both its static and
/// recursive terms.
///
/// Field names, types, and metadata come from the static term. A field is
/// nullable if either the static or the recursive term produces a nullable
/// value in that position, matching how `UNION` reconciles branch nullability.
///
/// Functional dependencies are intentionally dropped: the recursive term
/// appends rows that can duplicate values the static term guarantees unique, so
/// any FDs carried by the static term may not hold over the combined output.
fn recursive_query_output_schema(
static_schema: &DFSchemaRef,
recursive_schema: &DFSchemaRef,
) -> Result<DFSchemaRef> {
if static_schema.fields().len() != recursive_schema.fields().len() {
return Err(DataFusionError::Plan(format!(
"Non-recursive term and recursive term must have the same number of columns ({} != {})",
static_schema.fields().len(),
recursive_schema.fields().len()
)));
}

let fields = static_schema
.iter()
.zip(recursive_schema.fields())
.map(|((qualifier, static_field), recursive_field)| {
let nullable = static_field.is_nullable() || recursive_field.is_nullable();
(
qualifier.cloned(),
static_field.as_ref().clone().with_nullable(nullable).into(),
)
})
.collect::<Vec<_>>();

DFSchema::new_with_metadata(fields, static_schema.metadata().clone())
.map(DFSchemaRef::new)
}

/// Values expression. See
Expand Down Expand Up @@ -4613,6 +4702,74 @@ mod tests {
.build()
}

fn recursive_term_scan(name: &str, fields: Vec<Field>) -> Result<Arc<LogicalPlan>> {
Ok(Arc::new(
table_scan(Some(name), &Schema::new(fields), None)?.build()?,
))
}

#[test]
fn recursive_query_widens_nullability_per_column() -> Result<()> {
// Column `a` is non-nullable in both terms and must stay non-nullable;
// column `b` is non-nullable in the static term but nullable in the
// recursive term, so the output must widen it to nullable.
let static_term = recursive_term_scan(
"static",
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
],
)?;
let recursive_term = recursive_term_scan(
"rec",
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, true),
],
)?;

let query =
RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false)?;

// Names and types are taken from the static term.
assert_eq!(query.schema.field(0).name(), "a");
assert_eq!(query.schema.field(1).name(), "b");
assert_eq!(query.schema.field(0).data_type(), &DataType::Int32);
assert_eq!(query.schema.field(1).data_type(), &DataType::Int32);
// Nullability is widened independently per column.
assert!(!query.schema.field(0).is_nullable());
assert!(query.schema.field(1).is_nullable());
// `schema()` returns the widened recursive-query schema.
assert_eq!(
LogicalPlan::RecursiveQuery(query.clone()).schema(),
&query.schema
);
Ok(())
}

#[test]
fn recursive_query_rejects_column_count_mismatch() -> Result<()> {
let static_term =
recursive_term_scan("static", vec![Field::new("a", DataType::Int32, false)])?;
let recursive_term = recursive_term_scan(
"rec",
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
],
)?;

let err =
RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false)
.unwrap_err();
assert!(
err.strip_backtrace()
.contains("must have the same number of columns"),
"unexpected error: {err}"
);
Ok(())
}

#[test]
fn test_display_indent() -> Result<()> {
let plan = display_plan()?;
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,18 @@ impl TreeNode for LogicalPlan {
static_term,
recursive_term,
is_distinct,
schema,
}) => (static_term, recursive_term).map_elements(f)?.update_data(
|(static_term, recursive_term)| {
// Ordinary child rewrites preserve derived schemas. Call
// `LogicalPlan::recompute_schema` when child schemas should
// be reconciled again.
LogicalPlan::RecursiveQuery(RecursiveQuery {
name,
static_term,
recursive_term,
is_distinct,
schema,
})
},
),
Expand Down
6 changes: 5 additions & 1 deletion datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ pub trait ContextProvider {
not_impl_err!("Table Functions are not supported")
}

/// Provides an intermediate table that is used to store the results of a CTE during execution
/// Provides an intermediate table that is used to expose a recursive CTE
/// self-reference during planning and execution.
///
/// CTE stands for "Common Table Expression"
///
Expand All @@ -72,6 +73,9 @@ pub trait ContextProvider {
/// of the sql crate (for example [`CteWorkTable`]).
///
/// The [`ContextProvider`] provides a way to "hide" this dependency.
/// The schema argument is the schema to expose for scans of the recursive
/// self-reference, which may be more conservative than the final recursive
/// query output schema.
///
/// [`SqlToRel`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html
/// [`CteWorkTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/cte_worktable/struct.CteWorkTable.html
Expand Down
Loading
Loading