Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
first draft
  • Loading branch information
jayzhan211 committed Mar 8, 2025
commit 364e3a478cb6639ef58c995ea4a7da0de6c7a394
36 changes: 0 additions & 36 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,6 @@ fn test_propagate_empty_relation_inner_join_and_unions() {
assert_eq!(expected, format!("{plan}"));
}

#[test]
fn select_wildcard_with_repeated_column() {
let sql = "SELECT *, col_int32 FROM test";
let err = test_sql(sql).expect_err("query should have failed");
assert_eq!(
"Schema error: Schema contains duplicate qualified field name test.col_int32",
err.strip_backtrace()
);
}

#[test]
fn select_wildcard_with_repeated_column_but_is_aliased() {
let sql = "SELECT *, col_int32 as col_32 FROM test";
Expand Down Expand Up @@ -390,32 +380,6 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() {
assert_eq!(expected, format!("{plan}"));
}

// The test should return an error
// because the wildcard didn't be expanded before type coercion
#[test]
fn test_union_coercion_with_wildcard() -> Result<()> {
let dialect = PostgreSqlDialect {};
let context_provider = MyContextProvider::default();
let sql = "select * from (SELECT col_int32, col_uint32 FROM test) union all select * from(SELECT col_uint32, col_int32 FROM test)";
let statements = Parser::parse_sql(&dialect, sql)?;
let sql_to_rel = SqlToRel::new(&context_provider);
let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?;

if let LogicalPlan::Union(union) = logical_plan {
let err = TypeCoercionRewriter::coerce_union(union)
.err()
.unwrap()
.to_string();
assert_contains!(
err,
"Error during planning: Wildcard should be expanded before type coercion"
);
} else {
panic!("Expected Union plan");
}
Ok(())
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
Expand Down
71 changes: 62 additions & 9 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ use crate::utils::{

use datafusion_common::error::DataFusionErrorBuilder;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_common::{not_impl_err, plan_err, Column, Result};
use datafusion_common::{RecursionUnnestOption, UnnestOptions};
use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions};
use datafusion_expr::expr_rewriter::{
normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts,
};
use datafusion_expr::utils::{
expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs,
expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs
};
use datafusion_expr::{
qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter,
Expand Down Expand Up @@ -583,7 +583,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let mut error_builder = DataFusionErrorBuilder::new();
for expr in projection {
match self.sql_select_to_rex(expr, plan, empty_from, planner_context) {
Ok(expr) => prepared_select_exprs.push(expr),
Ok(expr) => prepared_select_exprs.extend(expr),
Err(err) => error_builder.add_error(err),
}
}
Expand All @@ -597,7 +597,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
plan: &LogicalPlan,
empty_from: bool,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
) -> Result<Vec<Expr>> {
match sql {
SelectItem::UnnamedExpr(expr) => {
let expr = self.sql_to_expr(expr, plan.schema(), planner_context)?;
Expand All @@ -606,7 +606,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
&[&[plan.schema()]],
&plan.using_columns()?,
)?;
Ok(col)
Ok(vec![col])
}
SelectItem::ExprWithAlias { expr, alias } => {
let select_expr =
Expand All @@ -622,7 +622,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
Expr::Column(column) if column.name.eq(&name) => col,
_ => col.alias(name),
};
Ok(expr)
Ok(vec![expr])
}
SelectItem::Wildcard(options) => {
Self::check_wildcard_options(&options)?;
Expand All @@ -635,7 +635,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
planner_context,
options,
)?;
Ok(wildcard_with_options(planned_options))
// Ok(vec![wildcard_with_options(planned_options)])

let expanded =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There two parts are the real change, others are test adjustment

expand_wildcard(plan.schema(), plan, Some(&planned_options))?;

let replaced = if let Some(replace) = planned_options.replace {
replace_columns(expanded, &replace)?
} else {
expanded
};

Ok(replaced)
}
SelectItem::QualifiedWildcard(object_name, options) => {
Self::check_wildcard_options(&options)?;
Expand All @@ -646,7 +657,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
planner_context,
options,
)?;
Ok(qualified_wildcard_with_options(qualifier, planned_options))

let expanded = expand_qualified_wildcard(
&qualifier,
plan.schema(),
Some(&planned_options),
)?;
// If there is a REPLACE statement, replace that column with the given
// replace expression. Column name remains the same.
let replaced = if let Some(replace) = planned_options.replace {
replace_columns(expanded, &replace)?
} else {
expanded
};

Ok(replaced)

// Ok(vec![qualified_wildcard_with_options(qualifier, planned_options)])
}
}
}
Expand Down Expand Up @@ -698,7 +725,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
planner_context,
)
})
.collect::<Result<Vec<_>>>()?;
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();
let planned_replace = PlannedReplaceSelectItem {
items: replace.items.into_iter().map(|i| *i).collect(),
planned_expressions: replace_expr,
Expand Down Expand Up @@ -884,3 +914,26 @@ fn match_window_definitions(
}
Ok(())
}

/// If there is a REPLACE statement in the projected expression in the form of
/// "REPLACE (some_column_within_an_expr AS some_column)", this function replaces
/// that column with the given replace expression. Column name remains the same.
/// Multiple REPLACEs are also possible with comma separations.
fn replace_columns(
mut exprs: Vec<Expr>,
replace: &PlannedReplaceSelectItem,
) -> Result<Vec<Expr>> {
for expr in exprs.iter_mut() {
if let Expr::Column(Column { name, .. }) = expr {
if let Some((_, new_expr)) = replace
.items()
.iter()
.zip(replace.expressions().iter())
.find(|(item, _)| item.column_name.value == *name)
{
*expr = new_expr.clone().alias(name.clone())
}
}
}
Ok(exprs)
}
137 changes: 8 additions & 129 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select * from (select * from j1 limit 10);",
expected:
"SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we expand * to all the columns, should we convert back to * for unparser?

I'm not sure should we support this, I delete related test and consider not supported

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's hard to convert back to * after removing Expr::Wildcard. It's okay for me to show all the column names for unparsing. They are equivalent.

parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
expected:
Expand Down Expand Up @@ -524,96 +517,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(SqliteDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)",
expected: r#"SELECT * FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(SqliteDialect {}),
},
Comment on lines 526 to 532
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can't remove them. They aren't related to the wildcard expansion. It's better to change the expected result to match the expanded result instead.

SELECT `temp_j`.`id2` FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j

TestStatementWithDialect {
sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)",
expected: r#"SELECT * FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(SqliteDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3])",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]), j1",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)",
expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3])",
expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)",
expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)",
expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]), j1",
expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id",
expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)",
expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * FROM UNNEST([4, 5, 6]) AS u (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT UNNEST([1,2,3])",
expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT UNNEST([1,2,3]) as c1",
expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#,
Expand All @@ -626,30 +529,6 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)",
expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)",
expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col) AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()),
},
TestStatementWithDialect {
sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)",
expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)",
expected: r#"SELECT * FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))") AS t1 (c1)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
];

for query in tests {
Expand Down Expand Up @@ -1456,13 +1335,13 @@ fn test_unnest_to_sql() {
fn test_join_with_no_conditions() {
sql_round_trip(
GenericDialect {},
"SELECT * FROM j1 JOIN j2",
"SELECT * FROM j1 CROSS JOIN j2",
"SELECT j1.j1_id FROM j1 JOIN j2",
"SELECT j1.j1_id FROM j1 CROSS JOIN j2",
);
sql_round_trip(
GenericDialect {},
"SELECT * FROM j1 CROSS JOIN j2",
"SELECT * FROM j1 CROSS JOIN j2",
"SELECT j1.j1_id FROM j1 CROSS JOIN j2",
"SELECT j1.j1_id FROM j1 CROSS JOIN j2",
);
}

Expand Down Expand Up @@ -1547,7 +1426,7 @@ impl UserDefinedLogicalNodeUnparser for UnusedUnparser {
fn test_unparse_extension_to_statement() -> Result<()> {
let dialect = GenericDialect {};
let statement = Parser::new(&dialect)
.try_with_sql("SELECT * FROM j1")?
.try_with_sql("SELECT j1.j1_id FROM j1")?
.parse_statement()?;
let state = MockSessionState::default();
let context = MockContextProvider { state };
Expand All @@ -1563,7 +1442,7 @@ fn test_unparse_extension_to_statement() -> Result<()> {
Arc::new(UnusedUnparser {}),
]);
let sql = unparser.plan_to_sql(&extension)?;
let expected = "SELECT * FROM j1";
let expected = "SELECT j1.j1_id FROM j1";
assert_eq!(sql.to_string(), expected);

if let Some(err) = plan_to_sql(&extension).err() {
Expand Down Expand Up @@ -1606,7 +1485,7 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser {
fn test_unparse_extension_to_sql() -> Result<()> {
let dialect = GenericDialect {};
let statement = Parser::new(&dialect)
.try_with_sql("SELECT * FROM j1")?
.try_with_sql("SELECT j1.j1_id FROM j1")?
.parse_statement()?;
let state = MockSessionState::default();
let context = MockContextProvider { state };
Expand All @@ -1626,7 +1505,7 @@ fn test_unparse_extension_to_sql() -> Result<()> {
Arc::new(UnusedUnparser {}),
]);
let sql = unparser.plan_to_sql(&plan)?;
let expected = "SELECT j1.j1_id AS user_id FROM (SELECT * FROM j1)";
let expected = "SELECT j1.j1_id AS user_id FROM (SELECT j1.j1_id FROM j1)";
assert_eq!(sql.to_string(), expected);

if let Some(err) = plan_to_sql(&plan).err() {
Expand Down
Loading