Skip to content

Commit 8022827

Browse files
authored
Fix optimizer regression with simplifying expressions in subquery filters (#3764)
1 parent e395e30 commit 8022827

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

datafusion/core/tests/sql/subqueries.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -336,10 +336,10 @@ order by s_name;
336336
Projection: part.p_partkey AS p_partkey, alias=__sq_1
337337
Filter: part.p_name LIKE Utf8("forest%")
338338
TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
339-
Projection: lineitem.l_partkey, lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
339+
Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
340340
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
341-
Filter: lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
342-
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
341+
Filter: lineitem.l_shipdate >= Date32("8766")
342+
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
343343
.to_string();
344344
assert_eq!(actual, expected);
345345

@@ -393,8 +393,8 @@ order by cntrycode;"#;
393393
TableScan: orders projection=[o_custkey]
394394
Projection: AVG(customer.c_acctbal) AS __value, alias=__sq_1
395395
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
396-
Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
397-
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
396+
Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
397+
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
398398
.to_string();
399399
assert_eq!(actual, expected);
400400

@@ -453,7 +453,7 @@ order by value desc;
453453
TableScan: supplier projection=[s_suppkey, s_nationkey]
454454
Filter: nation.n_name = Utf8("GERMANY")
455455
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
456-
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
456+
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
457457
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
458458
Inner Join: supplier.s_nationkey = nation.n_nationkey
459459
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey

datafusion/optimizer/src/optimizer.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,10 @@ impl Optimizer {
144144
Arc::new(DecorrelateWhereIn::new()),
145145
Arc::new(ScalarSubqueryToJoin::new()),
146146
Arc::new(SubqueryFilterToJoin::new()),
147+
// the SimplifyExpressions rule does not simplify expressions inside subqueries, so we
148+
// run it again after running the optimizations that potentially converted
149+
// subqueries to joins
150+
Arc::new(SimplifyExpressions::new()),
147151
Arc::new(EliminateFilter::new()),
148152
Arc::new(ReduceCrossJoin::new()),
149153
Arc::new(CommonSubexprEliminate::new()),

datafusion/optimizer/tests/integration-test.rs

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,28 @@ fn case_when() -> Result<()> {
5252
Ok(())
5353
}
5454

55+
#[test]
56+
fn subquery_filter_with_cast() -> Result<()> {
57+
// regression test for https://github.com/apache/arrow-datafusion/issues/3760
58+
let sql = "SELECT col_int32 FROM test \
59+
WHERE col_int32 > (\
60+
SELECT AVG(col_int32) FROM test \
61+
WHERE col_utf8 BETWEEN '2002-05-08' \
62+
AND (cast('2002-05-08' as date) + interval '5 days')\
63+
)";
64+
let plan = test_sql(sql)?;
65+
let expected =
66+
"Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\
67+
\n CrossJoin:\
68+
\n TableScan: test projection=[col_int32]\
69+
\n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\
70+
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
71+
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
72+
\n TableScan: test projection=[col_int32, col_utf8]";
73+
assert_eq!(expected, format!("{:?}", plan));
74+
Ok(())
75+
}
76+
5577
#[test]
5678
fn case_when_aggregate() -> Result<()> {
5779
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";

0 commit comments

Comments
 (0)