From 27ceec775ace5766d1c926d5f6cd46c18cb54705 Mon Sep 17 00:00:00 2001 From: Vince Date: Tue, 7 Feb 2023 14:51:36 +0100 Subject: [PATCH 1/9] Add support to SQL parser for unnest function. --- datafusion/expr/src/built_in_function.rs | 4 ++++ datafusion/expr/src/function.rs | 11 +++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 +++ datafusion/proto/src/generated/prost.rs | 3 +++ .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/tests/integration_test.rs | 19 +++++++++++++++++++ 8 files changed, 43 insertions(+) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 8ea96f185cb03..8e56abd263d79 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction { Struct, /// arrow_typeof ArrowTypeof, + /// unnest + Unnest, } impl BuiltinScalarFunction { @@ -261,6 +263,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::Unnest => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -312,6 +315,7 @@ impl FromStr for BuiltinScalarFunction { // array functions "make_array" => BuiltinScalarFunction::MakeArray, + "unnest" => BuiltinScalarFunction::Unnest, // string functions "ascii" => BuiltinScalarFunction::Ascii, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 50cc6bcd6c823..0a0c3c6f3c1e5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -257,6 +257,16 @@ pub fn return_type( BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8), + BuiltinScalarFunction::Unnest => match &input_expr_types[0] { + // Unnest function called on + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) => Ok(field.data_type().clone()), + _ => Err(DataFusionError::Internal( + "The unnest function only accepts list columns.".to_string(), + )), + }, + BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -608,6 +618,7 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.volatility(), ), BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), + BuiltinScalarFunction::Unnest => Signature::any(1, fun.volatility()), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8a9a8c0f6d748..be81f410cc536 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -519,6 +519,7 @@ enum ScalarFunction { CurrentDate = 70; CurrentTime = 71; Uuid = 72; + Unnest = 73; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 65d7af706b75e..0d1f4f1285ae9 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16821,6 +16821,7 @@ impl serde::Serialize for ScalarFunction { Self::CurrentDate => "CurrentDate", Self::CurrentTime => "CurrentTime", Self::Uuid => "Uuid", + Self::Unnest => "Unnest", }; serializer.serialize_str(variant) } @@ -16905,6 +16906,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "CurrentDate", "CurrentTime", "Uuid", + "Unnest", ]; struct GeneratedVisitor; @@ -17020,6 +17022,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "CurrentDate" => Ok(ScalarFunction::CurrentDate), "CurrentTime" => Ok(ScalarFunction::CurrentTime), "Uuid" => Ok(ScalarFunction::Uuid), + "Unnest" => Ok(ScalarFunction::Unnest), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 21816f5a90221..55c671ebf5542 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2100,6 +2100,7 @@ pub enum ScalarFunction { CurrentDate = 70, CurrentTime = 71, Uuid = 72, + Unnest = 73, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2181,6 +2182,7 @@ impl ScalarFunction { ScalarFunction::CurrentDate => "CurrentDate", ScalarFunction::CurrentTime => "CurrentTime", ScalarFunction::Uuid => "Uuid", + ScalarFunction::Unnest => "Unnest", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2259,6 +2261,7 @@ impl ScalarFunction { "CurrentDate" => Some(Self::CurrentDate), "CurrentTime" => Some(Self::CurrentTime), "Uuid" => Some(Self::Uuid), + "Unnest" => Some(Self::Unnest), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 5aaf79997b638..d57f19d783dc2 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -472,6 +472,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::Unnest => Self::Unnest, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 90c99c0d96bd4..915e7b8f78d39 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1286,6 +1286,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::Unnest => Self::Unnest, }; Ok(scalar_function) diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 761bbf345b2ca..2232a0e8e243b 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2399,6 +2399,16 @@ fn select_groupby_orderby() { quick_test(sql, expected); } +#[test] +fn unnest_column() { + let sql = r#"SELECT first_name, last_name, UNNEST(aliases) FROM unnested_test"#; + // expect projection with unnest call on aliases. + let expected = + "Projection: unnested_test.first_name, unnested_test.last_name, unnest(unnested_test.aliases)\ + \n TableScan: unnested_test"; + quick_test(sql, expected); +} + fn logical_plan(sql: &str) -> Result { logical_plan_with_options(sql, ParserOptions::default()) } @@ -2542,6 +2552,15 @@ impl ContextProvider for MockContextProvider { Field::new("c12", DataType::Float64, false), Field::new("c13", DataType::Utf8, false), ])), + "unnested_test" => Ok(Schema::new(vec![ + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new( + "aliases", + DataType::List(Box::new(Field::new("item", DataType::Utf8, false))), + false, + ), + ])), _ => Err(DataFusionError::Plan(format!( "No table named: {} found", name.table() From a6866dc4c3fc4fb6c1222cfd8639030d3f8b8384 Mon Sep 17 00:00:00 2001 From: Vince Date: Thu, 9 Feb 2023 12:12:02 +0100 Subject: [PATCH 2/9] Revert scalar function changes --- datafusion/expr/src/built_in_function.rs | 4 ---- datafusion/expr/src/function.rs | 11 ----------- datafusion/proto/proto/datafusion.proto | 1 - datafusion/proto/src/generated/pbjson.rs | 3 --- datafusion/proto/src/generated/prost.rs | 3 --- datafusion/proto/src/logical_plan/from_proto.rs | 1 - 6 files changed, 23 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 8e56abd263d79..8ea96f185cb03 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -174,8 +174,6 @@ pub enum BuiltinScalarFunction { Struct, /// arrow_typeof ArrowTypeof, - /// unnest - Unnest, } impl BuiltinScalarFunction { @@ -263,7 +261,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, - BuiltinScalarFunction::Unnest => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -315,7 +312,6 @@ impl FromStr for BuiltinScalarFunction { // array functions "make_array" => BuiltinScalarFunction::MakeArray, - "unnest" => BuiltinScalarFunction::Unnest, // string functions "ascii" => BuiltinScalarFunction::Ascii, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 0a0c3c6f3c1e5..50cc6bcd6c823 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -257,16 +257,6 @@ pub fn return_type( BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8), - BuiltinScalarFunction::Unnest => match &input_expr_types[0] { - // Unnest function called on - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => Ok(field.data_type().clone()), - _ => Err(DataFusionError::Internal( - "The unnest function only accepts list columns.".to_string(), - )), - }, - BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -618,7 +608,6 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.volatility(), ), BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::Unnest => Signature::any(1, fun.volatility()), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index be81f410cc536..8a9a8c0f6d748 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -519,7 +519,6 @@ enum ScalarFunction { CurrentDate = 70; CurrentTime = 71; Uuid = 72; - Unnest = 73; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0d1f4f1285ae9..65d7af706b75e 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16821,7 +16821,6 @@ impl serde::Serialize for ScalarFunction { Self::CurrentDate => "CurrentDate", Self::CurrentTime => "CurrentTime", Self::Uuid => "Uuid", - Self::Unnest => "Unnest", }; serializer.serialize_str(variant) } @@ -16906,7 +16905,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "CurrentDate", "CurrentTime", "Uuid", - "Unnest", ]; struct GeneratedVisitor; @@ -17022,7 +17020,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "CurrentDate" => Ok(ScalarFunction::CurrentDate), "CurrentTime" => Ok(ScalarFunction::CurrentTime), "Uuid" => Ok(ScalarFunction::Uuid), - "Unnest" => Ok(ScalarFunction::Unnest), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 55c671ebf5542..21816f5a90221 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2100,7 +2100,6 @@ pub enum ScalarFunction { CurrentDate = 70, CurrentTime = 71, Uuid = 72, - Unnest = 73, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2182,7 +2181,6 @@ impl ScalarFunction { ScalarFunction::CurrentDate => "CurrentDate", ScalarFunction::CurrentTime => "CurrentTime", ScalarFunction::Uuid => "Uuid", - ScalarFunction::Unnest => "Unnest", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2261,7 +2259,6 @@ impl ScalarFunction { "CurrentDate" => Some(Self::CurrentDate), "CurrentTime" => Some(Self::CurrentTime), "Uuid" => Some(Self::Uuid), - "Unnest" => Some(Self::Unnest), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d57f19d783dc2..5aaf79997b638 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -472,7 +472,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, - ScalarFunction::Unnest => Self::Unnest, } } } From 4804d20cf1a907d4fe83f8e081a99d57053b33ed Mon Sep 17 00:00:00 2001 From: Vince Date: Thu, 9 Feb 2023 12:52:48 +0100 Subject: [PATCH 3/9] Revert scalar function changes --- datafusion/proto/src/logical_plan/to_proto.rs | 1 - datafusion/sql/tests/integration_test.rs | 19 ------------------- 2 files changed, 20 deletions(-) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 915e7b8f78d39..90c99c0d96bd4 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1286,7 +1286,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, - BuiltinScalarFunction::Unnest => Self::Unnest, }; Ok(scalar_function) diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index b5a384169466e..38ede4e6bbb56 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2399,16 +2399,6 @@ fn select_groupby_orderby() { quick_test(sql, expected); } -#[test] -fn unnest_column() { - let sql = r#"SELECT first_name, last_name, UNNEST(aliases) FROM unnested_test"#; - // expect projection with unnest call on aliases. - let expected = - "Projection: unnested_test.first_name, unnested_test.last_name, unnest(unnested_test.aliases)\ - \n TableScan: unnested_test"; - quick_test(sql, expected); -} - fn logical_plan(sql: &str) -> Result { logical_plan_with_options(sql, ParserOptions::default()) } @@ -2552,15 +2542,6 @@ impl ContextProvider for MockContextProvider { Field::new("c12", DataType::Float64, false), Field::new("c13", DataType::Utf8, false), ])), - "unnested_test" => Ok(Schema::new(vec![ - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new( - "aliases", - DataType::List(Box::new(Field::new("item", DataType::Utf8, false))), - false, - ), - ])), _ => Err(DataFusionError::Plan(format!( "No table named: {} found", name.table() From 73ad14789a40ba26fcbbd62f1ffc416b3a3c1945 Mon Sep 17 00:00:00 2001 From: Vince Date: Thu, 9 Feb 2023 16:01:22 +0100 Subject: [PATCH 4/9] Add support for SQL unnest --- datafusion/sql/src/select.rs | 92 +++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c6317747856f3..2e42d0f25e4ba 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -33,7 +33,9 @@ use datafusion_expr::Expr::Alias; use datafusion_expr::{ Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; -use sqlparser::ast::{Expr as SQLExpr, WildcardAdditionalOptions}; +use sqlparser::ast::{ + Expr as SQLExpr, FunctionArg, FunctionArgExpr, WildcardAdditionalOptions, +}; use sqlparser::ast::{Select, SelectItem, TableWithJoins}; use std::collections::HashSet; use std::sync::Arc; @@ -46,6 +48,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { + println!("select_to_plan: {select:#?}"); + // check for unsupported syntax first if !select.cluster_by.is_empty() { return Err(DataFusionError::NotImplemented("CLUSTER BY".to_string())); @@ -76,10 +80,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &from_schema, )?; + // Extract unnest projection if present. + let (projection, unnest) = unnest_projection(select.projection)?; + // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( &plan, - select.projection, + projection, + empty_from, + planner_context, + &from_schema, + )?; + + let unnest_exprs = self.prepare_select_exprs( + &plan, + unnest, empty_from, planner_context, &from_schema, @@ -87,6 +102,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // having and group by clause may reference aliases defined in select projection let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let mut combined_schema = (**projected_plan.schema()).clone(); combined_schema.merge(plan.schema()); @@ -204,6 +220,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // final projection let plan = project(plan, select_exprs_post_aggr)?; + // Apply unnesting + let plan = self.unnest(plan, unnest_exprs)?; + // process distinct clause let plan = if select.distinct { LogicalPlanBuilder::from(plan).distinct()?.build() @@ -476,6 +495,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { LogicalPlanBuilder::from(input).project(expr)?.build() } + /// Unnest columns from the input plan. + fn unnest(&self, mut input: LogicalPlan, exprs: Vec) -> Result { + if exprs.is_empty() { + Ok(input) + } else { + self.validate_schema_satisfies_exprs(input.schema(), &exprs)?; + + for expr in exprs { + let Expr::Column(column) = expr else { + // This should never happen as we did validation above. + return Err(DataFusionError::Internal("Not a column".to_string())); + }; + + input = LogicalPlanBuilder::from(input) + .unnest_column(column)? + .build()?; + } + + Ok(input) + } + } + /// Create an aggregate plan. /// /// An aggregate plan consists of grouping expressions, aggregate expressions, and an @@ -586,3 +627,50 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok((plan, select_exprs_post_aggr, having_expr_post_aggr)) } } + +/// Extract unnest calls from the projection. +fn unnest_projection( + projection: Vec, +) -> Result<(Vec, Vec)> { + let mut unnest_cols = Vec::with_capacity(projection.len()); + let projection = projection + .into_iter() + .map(|expr| match expr { + SelectItem::UnnamedExpr(SQLExpr::Function(ref f)) + | SelectItem::ExprWithAlias { + expr: SQLExpr::Function(ref f), + .. + } => { + // If this is an unnest function call replace the function call with its + // column argument and pass the unnest column expression to the caller. + if f.name.to_string().to_lowercase() == "unnest" { + if f.args.len() != 1 { + Err(DataFusionError::Plan( + "Unnest must have one column argument".to_string(), + )) + } else { + // Extract the function column and put it in the projection. + match &f.args[0] { + FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) + | FunctionArg::Named { + name: _, + arg: FunctionArgExpr::Expr(arg), + } => { + let item = SelectItem::UnnamedExpr(arg.clone()); + unnest_cols.push(item.clone()); + Ok(item) + } + arg => Err(DataFusionError::Plan(format!( + "Unsupported unnest argument: {arg:?}" + ))), + } + } + } else { + Ok(expr) + } + } + _ => Ok(expr), + }) + .collect::>>()?; + Ok((projection, unnest_cols)) +} From b4ba912f06707fe17d624ccec001e53ec7aaef23 Mon Sep 17 00:00:00 2001 From: Vince Date: Thu, 9 Feb 2023 17:20:18 +0100 Subject: [PATCH 5/9] Remove debug print --- datafusion/sql/src/select.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2e42d0f25e4ba..bad2d8b98643b 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -48,8 +48,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { - println!("select_to_plan: {select:#?}"); - // check for unsupported syntax first if !select.cluster_by.is_empty() { return Err(DataFusionError::NotImplemented("CLUSTER BY".to_string())); @@ -102,7 +100,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // having and group by clause may reference aliases defined in select projection let projected_plan = self.project(plan.clone(), select_exprs.clone())?; - let mut combined_schema = (**projected_plan.schema()).clone(); combined_schema.merge(plan.schema()); From e99858b17805cc64de24698b2f778188375d1614 Mon Sep 17 00:00:00 2001 From: Vince Date: Sun, 12 Feb 2023 15:47:01 +0100 Subject: [PATCH 6/9] Fix plan for aggregates on unnested columns --- datafusion/sql/src/select.rs | 111 +++++++++++++++++++++++------------ 1 file changed, 74 insertions(+), 37 deletions(-) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index bad2d8b98643b..c867fbf1c797c 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; use sqlparser::ast::{ - Expr as SQLExpr, FunctionArg, FunctionArgExpr, WildcardAdditionalOptions, + Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, WildcardAdditionalOptions, }; use sqlparser::ast::{Select, SelectItem, TableWithJoins}; use std::collections::HashSet; @@ -64,6 +64,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // build from schema for unqualifier column ambiguous check // we should get only one field for unqualifier column from schema. @@ -98,8 +99,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &from_schema, )?; + // Apply unnesting + let plan = self.unnest(plan.clone(), unnest_exprs.clone())?; + // having and group by clause may reference aliases defined in select projection let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let mut combined_schema = (**projected_plan.schema()).clone(); combined_schema.merge(plan.schema()); @@ -217,9 +222,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // final projection let plan = project(plan, select_exprs_post_aggr)?; - // Apply unnesting - let plan = self.unnest(plan, unnest_exprs)?; - // process distinct clause let plan = if select.distinct { LogicalPlanBuilder::from(plan).distinct()?.build() @@ -633,41 +635,76 @@ fn unnest_projection( let projection = projection .into_iter() .map(|expr| match expr { - SelectItem::UnnamedExpr(SQLExpr::Function(ref f)) - | SelectItem::ExprWithAlias { - expr: SQLExpr::Function(ref f), - .. - } => { - // If this is an unnest function call replace the function call with its - // column argument and pass the unnest column expression to the caller. - if f.name.to_string().to_lowercase() == "unnest" { - if f.args.len() != 1 { - Err(DataFusionError::Plan( - "Unnest must have one column argument".to_string(), - )) - } else { - // Extract the function column and put it in the projection. - match &f.args[0] { - FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) - | FunctionArg::Named { - name: _, - arg: FunctionArgExpr::Expr(arg), - } => { - let item = SelectItem::UnnamedExpr(arg.clone()); - unnest_cols.push(item.clone()); - Ok(item) - } - arg => Err(DataFusionError::Plan(format!( - "Unsupported unnest argument: {arg:?}" - ))), - } - } - } else { - Ok(expr) - } - } + SelectItem::UnnamedExpr(expr) => Ok(SelectItem::UnnamedExpr(extract_unnest( + expr, + &mut unnest_cols, + )?)), + SelectItem::ExprWithAlias { expr, alias } => Ok(SelectItem::ExprWithAlias { + expr: extract_unnest(expr, &mut unnest_cols)?, + alias, + }), _ => Ok(expr), }) .collect::>>()?; + Ok((projection, unnest_cols)) } + +fn extract_unnest(expr: SQLExpr, unnest_cols: &mut Vec) -> Result { + match expr { + SQLExpr::Nested(expr) => Ok(SQLExpr::Nested(Box::new(extract_unnest( + *expr, + unnest_cols, + )?))), + SQLExpr::Function(f) => { + let mut args = f + .args + .into_iter() + .map(|arg| match arg { + FunctionArg::Named { name, arg } => match arg { + FunctionArgExpr::Expr(expr) => Ok(FunctionArg::Named { + name, + arg: FunctionArgExpr::Expr(extract_unnest( + expr, + unnest_cols, + )?), + }), + _ => Ok(FunctionArg::Named { name, arg }), + }, + FunctionArg::Unnamed(arg) => match arg { + FunctionArgExpr::Expr(expr) => Ok(FunctionArg::Unnamed( + FunctionArgExpr::Expr(extract_unnest(expr, unnest_cols)?), + )), + _ => Ok(FunctionArg::Unnamed(arg)), + }, + }) + .collect::>>()?; + + let expr = if f.name.to_string().to_lowercase() == "unnest" { + if args.len() != 1 { + return Err(DataFusionError::Plan( + "Unnest must have one column argument".to_string(), + )); + } else { + match args.pop().unwrap() { + FunctionArg::Named { arg, .. } | FunctionArg::Unnamed(arg) => { + if let FunctionArgExpr::Expr(expr) = arg { + unnest_cols.push(SelectItem::UnnamedExpr(expr.clone())); + expr + } else { + return Err(DataFusionError::Plan( + "Invalid unnest argument".to_string(), + )); + } + } + } + } + } else { + SQLExpr::Function(Function { args, ..f }) + }; + + Ok(expr) + } + _ => Ok(expr), + } +} From 40c5da0ffa146d7ff690419514b77d3e68cf6de8 Mon Sep 17 00:00:00 2001 From: Vince Date: Sun, 12 Feb 2023 16:25:33 +0100 Subject: [PATCH 7/9] Fix clippy warning. --- datafusion/sql/src/select.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c867fbf1c797c..40d62d726ee17 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -100,7 +100,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; // Apply unnesting - let plan = self.unnest(plan.clone(), unnest_exprs.clone())?; + let plan = self.unnest(plan.clone(), unnest_exprs)?; // having and group by clause may reference aliases defined in select projection let projected_plan = self.project(plan.clone(), select_exprs.clone())?; From c4a5f4d2584a0ba753e466c5a6bc39dabd051e00 Mon Sep 17 00:00:00 2001 From: Vince Date: Mon, 13 Feb 2023 14:21:58 +0100 Subject: [PATCH 8/9] Add sql unnest tests --- datafusion/core/tests/sql/mod.rs | 1 + datafusion/core/tests/sql/unnest.rs | 263 ++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+) create mode 100644 datafusion/core/tests/sql/unnest.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 4b3c60d7e01f6..a629059c2a217 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -101,6 +101,7 @@ pub mod select; pub mod timestamp; pub mod udf; pub mod union; +pub mod unnest; pub mod wildcard; pub mod window; diff --git a/datafusion/core/tests/sql/unnest.rs b/datafusion/core/tests/sql/unnest.rs new file mode 100644 index 0000000000000..485bb8d4ce58c --- /dev/null +++ b/datafusion/core/tests/sql/unnest.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn unnest_single_columns() -> Result<()> { + let ctx = create_nested_context().await?; + + let sql = "select * from shapes"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + r#"+----------+-----------------------------------------------------------+--------------------+"#, + r#"| shape_id | points | tags |"#, + r#"+----------+-----------------------------------------------------------+--------------------+"#, + r#"| 1 | [{"x": 4, "y": -7}, {"x": 9, "y": 6}] | [tag1, tag2, tag3] |"#, + r#"| 2 | | [tag1, tag2] |"#, + r#"| 3 | [{"x": -1, "y": 5}, {"x": -7, "y": 0}] | [tag1] |"#, + r#"| 4 | [{"x": 3, "y": -4}, {"x": -6, "y": -8}] | [tag1, tag2] |"#, + r#"| 5 | [{"x": 5, "y": -3}, {"x": -4, "y": 3}, {"x": -6, "y": 0}] | |"#, + r#"+----------+-----------------------------------------------------------+--------------------+"#, + ]; + assert_batches_eq!(expected, &results); + + let sql = "select shape_id, unnest(tags) from shapes"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------+------+", + "| shape_id | tags |", + "+----------+------+", + "| 1 | tag1 |", + "| 1 | tag2 |", + "| 1 | tag3 |", + "| 2 | tag1 |", + "| 2 | tag2 |", + "| 3 | tag1 |", + "| 4 | tag1 |", + "| 4 | tag2 |", + "| 5 | |", + "+----------+------+", + ]; + assert_batches_eq!(expected, &results); + + let sql = "select shape_id, unnest(points) from shapes"; + + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + r#"+----------+--------------------+"#, + r#"| shape_id | points |"#, + r#"+----------+--------------------+"#, + r#"| 1 | {"x": 4, "y": -7} |"#, + r#"| 1 | {"x": 9, "y": 6} |"#, + r#"| 2 | |"#, + r#"| 3 | {"x": -1, "y": 5} |"#, + r#"| 3 | {"x": -7, "y": 0} |"#, + r#"| 4 | {"x": 3, "y": -4} |"#, + r#"| 4 | {"x": -6, "y": -8} |"#, + r#"| 5 | {"x": 5, "y": -3} |"#, + r#"| 5 | {"x": -4, "y": 3} |"#, + r#"| 5 | {"x": -6, "y": 0} |"#, + r#"+----------+--------------------+"#, + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_multiple_columns() -> Result<()> { + let ctx = create_nested_context().await?; + + let sql = "select shape_id, unnest(points), unnest(tags) from shapes"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + r#"+----------+--------------------+------+"#, + r#"| shape_id | points | tags |"#, + r#"+----------+--------------------+------+"#, + r#"| 1 | {"x": 4, "y": -7} | tag1 |"#, + r#"| 1 | {"x": 4, "y": -7} | tag2 |"#, + r#"| 1 | {"x": 4, "y": -7} | tag3 |"#, + r#"| 1 | {"x": 9, "y": 6} | tag1 |"#, + r#"| 1 | {"x": 9, "y": 6} | tag2 |"#, + r#"| 1 | {"x": 9, "y": 6} | tag3 |"#, + r#"| 2 | | tag1 |"#, + r#"| 2 | | tag2 |"#, + r#"| 3 | {"x": -1, "y": 5} | tag1 |"#, + r#"| 3 | {"x": -7, "y": 0} | tag1 |"#, + r#"| 4 | {"x": 3, "y": -4} | tag1 |"#, + r#"| 4 | {"x": 3, "y": -4} | tag2 |"#, + r#"| 4 | {"x": -6, "y": -8} | tag1 |"#, + r#"| 4 | {"x": -6, "y": -8} | tag2 |"#, + r#"| 5 | {"x": 5, "y": -3} | |"#, + r#"| 5 | {"x": -4, "y": 3} | |"#, + r#"| 5 | {"x": -6, "y": 0} | |"#, + r#"+----------+--------------------+------+"#, + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_distinct() -> Result<()> { + let ctx = create_nested_context().await?; + + let sql = "select distinct(unnest(tags)) from shapes order by 1;"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+", + "| tags |", + "+------+", + "| tag1 |", + "| tag2 |", + "| tag3 |", + "| |", + "+------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_group_by() -> Result<()> { + let ctx = create_nested_context().await?; + + let sql = "select shape_id, count(unnest(tags)) as tag \ + from shapes \ + group by shape_id \ + order by 1"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------+-----+", + "| shape_id | tag |", + "+----------+-----+", + "| 1 | 3 |", + "| 2 | 2 |", + "| 3 | 1 |", + "| 4 | 2 |", + "| 5 | 0 |", + "+----------+-----+", + ]; + assert_batches_eq!(expected, &results); + + let sql = "select shape_id, count(unnest(points)) as point \ + from shapes \ + group by shape_id \ + order by 1"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----------+-------+", + "| shape_id | point |", + "+----------+-------+", + "| 1 | 2 |", + "| 2 | 0 |", + "| 3 | 2 |", + "| 4 | 2 |", + "| 5 | 3 |", + "+----------+-------+", + ]; + assert_batches_eq!(expected, &results); + + let sql = "select unnest(tags) as tag, count(*) \ + from shapes \ + group by tag \ + order by 1"; + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+-----------------+", + "| tag | COUNT(UInt8(1)) |", + "+------+-----------------+", + "| tag1 | 4 |", + "| tag2 | 3 |", + "| tag3 | 1 |", + "| | 1 |", + "+------+-----------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +/// Create a context that contains nested types. +/// +/// Create a data frame with nested types, each row contains: +/// - shape_id an integer primary key +/// - points A list of points structs {x, y} +/// - A list of tags choosen at random from tag1 to tag10. +async fn create_nested_context() -> Result { + use rand::prelude::*; + const NUM_ROWS: usize = 5; + + let mut shape_id_builder = UInt32Builder::new(); + let mut points_builder = ListBuilder::new(StructBuilder::from_fields( + vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Int32, false), + ], + 5, + )); + let mut tags_builder = ListBuilder::new(StringBuilder::new()); + + let mut rng = StdRng::seed_from_u64(97); + + for idx in 0..NUM_ROWS { + // Append shape id. + shape_id_builder.append_value(idx as u32 + 1); + + // Add a random number of points + let num_points: usize = rng.gen_range(0..4); + if num_points > 0 { + for _ in 0..num_points.max(2) { + // Add x value + points_builder + .values() + .field_builder::(0) + .unwrap() + .append_value(rng.gen_range(-10..10)); + // Add y value + points_builder + .values() + .field_builder::(1) + .unwrap() + .append_value(rng.gen_range(-10..10)); + points_builder.values().append(true); + } + } + + // Append null if num points is 0. + points_builder.append(num_points > 0); + + // Append tags. + let num_tags: usize = rng.gen_range(0..5); + for id in 0..num_tags { + tags_builder.values().append_value(format!("tag{}", id + 1)); + } + tags_builder.append(num_tags > 0); + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("points", Arc::new(points_builder.finish()) as ArrayRef), + ("tags", Arc::new(tags_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + Ok(ctx) +} From c6287d6b18b5648a46ac2f10aa0eeac5afd7a6fe Mon Sep 17 00:00:00 2001 From: Vince Date: Mon, 13 Feb 2023 14:27:34 +0100 Subject: [PATCH 9/9] Fix rustfmt --- datafusion/core/tests/sql/unnest.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/core/tests/sql/unnest.rs b/datafusion/core/tests/sql/unnest.rs index 485bb8d4ce58c..23c1ff6fabf21 100644 --- a/datafusion/core/tests/sql/unnest.rs +++ b/datafusion/core/tests/sql/unnest.rs @@ -119,6 +119,9 @@ async fn unnest_distinct() -> Result<()> { let sql = "select distinct(unnest(tags)) from shapes order by 1;"; let results = execute_to_batches(&ctx, sql).await; + + // rustfmt makes it hard to read. + #[rustfmt::skip] let expected = vec![ "+------+", "| tags |",