Skip to content

Commit ba931da

Browse files
authored
fix(sql): avoid excessive inlining during Select merge (ibis-project#8825)
1 parent 1237fe3 commit ba931da

8 files changed

Lines changed: 377 additions & 266 deletions

File tree

ibis/backends/sql/rewrites.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,30 @@ def first_to_firstvalue(_, **kwargs):
119119
return _.copy(func=klass(_.func.arg))
120120

121121

122+
def complexity(node):
123+
"""Assign a complexity score to a node.
124+
125+
Subsequent projections can be merged into a single projection by replacing
126+
the fields referenced in the outer projection with the computed expressions
127+
from the inner projection. This inlining can result in very complex value
128+
expressions depending on the projections. In order to prevent excessive
129+
inlining, we assign a complexity score to each node.
130+
131+
The complexity score assigns 1 to each value expression and adds up in the
132+
tree hierarchy unless there is a Field node where we don't add up the
133+
complexity of the referenced relation. This way we treat fields kind of like
134+
reusable variables, considering them less complex than if they were inlined.
135+
"""
136+
137+
def accum(node, *args):
138+
if isinstance(node, ops.Field):
139+
return 1
140+
else:
141+
return 1 + sum(args)
142+
143+
return node.map_nodes(accum)[node]
144+
145+
122146
@replace(Object(Select, Object(Select)))
123147
def merge_select_select(_, **kwargs):
124148
"""Merge subsequent Select relations into one.
@@ -128,15 +152,11 @@ def merge_select_select(_, **kwargs):
128152
from the inner Select are inlined into the outer Select.
129153
"""
130154
# don't merge if either the outer or the inner select contains window functions, EXISTS/IN subqueries, or unnest operations
131-
for v in _.selections.values():
132-
if v.find(ops.WindowFunction, filter=ops.Value):
133-
return _
134-
for v in _.parent.selections.values():
135-
if v.find((ops.WindowFunction, ops.Unnest), filter=ops.Value):
136-
return _
137-
for v in _.predicates:
138-
if v.find((ops.ExistsSubquery, ops.InSubquery), filter=ops.Value):
139-
return _
155+
blocking = (ops.WindowFunction, ops.ExistsSubquery, ops.InSubquery, ops.Unnest)
156+
if _.find_below(blocking, filter=ops.Value):
157+
return _
158+
if _.parent.find_below(blocking, filter=ops.Value):
159+
return _
140160

141161
subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
142162
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}
@@ -151,12 +171,13 @@ def merge_select_select(_, **kwargs):
151171
)
152172
unique_sort_keys = sort_keys + parent_sort_keys
153173

154-
return Select(
174+
result = Select(
155175
_.parent.parent,
156176
selections=selections,
157177
predicates=unique_predicates,
158178
sort_keys=unique_sort_keys,
159179
)
180+
return result if complexity(result) <= complexity(_) else _
160181

161182

162183
def extract_ctes(node):
@@ -198,7 +219,8 @@ def sqlize(
198219
assert isinstance(node, ops.Relation)
199220

200221
# apply the backend specific rewrites
201-
node = node.replace(reduce(operator.or_, rewrites))
222+
if rewrites:
223+
node = node.replace(reduce(operator.or_, rewrites))
202224

203225
# lower the expression graph to a SQL-like relational algebra
204226
context = {"params": params}
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
SELECT
2-
IIF(
2+
IIF([t2].[InSubquery(x)] <> 0, 1, 0) AS [InSubquery(x)]
3+
FROM (
4+
SELECT
35
[t0].[x] IN (
46
SELECT
57
[t0].[x]
68
FROM [t] AS [t0]
79
WHERE
810
[t0].[x] > 2
9-
),
10-
1,
11-
0
12-
) AS [InSubquery(x)]
13-
FROM [t] AS [t0]
11+
) AS [InSubquery(x)]
12+
FROM [t] AS [t0]
13+
) AS [t2]

ibis/backends/tests/test_generic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,11 @@ def test_typeof(con):
852852
reason="https://github.com/risingwavelabs/risingwave/issues/1343",
853853
)
854854
@pytest.mark.xfail_version(dask=["dask<2024.2.0"])
855+
@pytest.mark.notyet(
856+
["mssql"],
857+
raises=PyODBCProgrammingError,
858+
reason="naked IN queries are not supported",
859+
)
855860
def test_isin_uncorrelated(
856861
backend, batting, awards_players, batting_df, awards_players_df
857862
):
Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,90 @@
11
SELECT
2-
"t9"."s_name",
3-
"t9"."s_address"
2+
"t13"."s_name",
3+
"t13"."s_address"
44
FROM (
55
SELECT
6-
"t5"."s_suppkey",
7-
"t5"."s_name",
8-
"t5"."s_address",
9-
"t5"."s_nationkey",
10-
"t5"."s_phone",
11-
"t5"."s_acctbal",
12-
"t5"."s_comment",
13-
"t6"."n_nationkey",
14-
"t6"."n_name",
15-
"t6"."n_regionkey",
16-
"t6"."n_comment"
17-
FROM "supplier" AS "t5"
18-
INNER JOIN "nation" AS "t6"
19-
ON "t5"."s_nationkey" = "t6"."n_nationkey"
20-
) AS "t9"
21-
WHERE
22-
"t9"."n_name" = 'CANADA'
23-
AND "t9"."s_suppkey" IN (
6+
"t9"."s_suppkey",
7+
"t9"."s_name",
8+
"t9"."s_address",
9+
"t9"."s_nationkey",
10+
"t9"."s_phone",
11+
"t9"."s_acctbal",
12+
"t9"."s_comment",
13+
"t9"."n_nationkey",
14+
"t9"."n_name",
15+
"t9"."n_regionkey",
16+
"t9"."n_comment"
17+
FROM (
2418
SELECT
25-
"t1"."ps_suppkey"
26-
FROM "partsupp" AS "t1"
27-
WHERE
28-
"t1"."ps_partkey" IN (
19+
"t5"."s_suppkey",
20+
"t5"."s_name",
21+
"t5"."s_address",
22+
"t5"."s_nationkey",
23+
"t5"."s_phone",
24+
"t5"."s_acctbal",
25+
"t5"."s_comment",
26+
"t6"."n_nationkey",
27+
"t6"."n_name",
28+
"t6"."n_regionkey",
29+
"t6"."n_comment"
30+
FROM "supplier" AS "t5"
31+
INNER JOIN "nation" AS "t6"
32+
ON "t5"."s_nationkey" = "t6"."n_nationkey"
33+
) AS "t9"
34+
WHERE
35+
"t9"."n_name" = 'CANADA'
36+
AND "t9"."s_suppkey" IN (
37+
SELECT
38+
"t11"."ps_suppkey"
39+
FROM (
2940
SELECT
30-
"t3"."p_partkey"
31-
FROM "part" AS "t3"
41+
"t2"."ps_partkey",
42+
"t2"."ps_suppkey",
43+
"t2"."ps_availqty",
44+
"t2"."ps_supplycost",
45+
"t2"."ps_comment"
46+
FROM "partsupp" AS "t2"
3247
WHERE
33-
"t3"."p_name" LIKE 'forest%'
34-
)
35-
AND "t1"."ps_availqty" > (
36-
(
37-
SELECT
38-
SUM("t8"."l_quantity") AS "Sum(l_quantity)"
39-
FROM (
48+
"t2"."ps_partkey" IN (
4049
SELECT
41-
"t4"."l_orderkey",
42-
"t4"."l_partkey",
43-
"t4"."l_suppkey",
44-
"t4"."l_linenumber",
45-
"t4"."l_quantity",
46-
"t4"."l_extendedprice",
47-
"t4"."l_discount",
48-
"t4"."l_tax",
49-
"t4"."l_returnflag",
50-
"t4"."l_linestatus",
51-
"t4"."l_shipdate",
52-
"t4"."l_commitdate",
53-
"t4"."l_receiptdate",
54-
"t4"."l_shipinstruct",
55-
"t4"."l_shipmode",
56-
"t4"."l_comment"
57-
FROM "lineitem" AS "t4"
50+
"t3"."p_partkey"
51+
FROM "part" AS "t3"
5852
WHERE
59-
"t4"."l_partkey" = "t1"."ps_partkey"
60-
AND "t4"."l_suppkey" = "t1"."ps_suppkey"
61-
AND "t4"."l_shipdate" >= MAKE_DATE(1994, 1, 1)
62-
AND "t4"."l_shipdate" < MAKE_DATE(1995, 1, 1)
63-
) AS "t8"
64-
) * CAST(0.5 AS DOUBLE)
65-
)
66-
)
53+
"t3"."p_name" LIKE 'forest%'
54+
)
55+
AND "t2"."ps_availqty" > (
56+
(
57+
SELECT
58+
SUM("t8"."l_quantity") AS "Sum(l_quantity)"
59+
FROM (
60+
SELECT
61+
"t4"."l_orderkey",
62+
"t4"."l_partkey",
63+
"t4"."l_suppkey",
64+
"t4"."l_linenumber",
65+
"t4"."l_quantity",
66+
"t4"."l_extendedprice",
67+
"t4"."l_discount",
68+
"t4"."l_tax",
69+
"t4"."l_returnflag",
70+
"t4"."l_linestatus",
71+
"t4"."l_shipdate",
72+
"t4"."l_commitdate",
73+
"t4"."l_receiptdate",
74+
"t4"."l_shipinstruct",
75+
"t4"."l_shipmode",
76+
"t4"."l_comment"
77+
FROM "lineitem" AS "t4"
78+
WHERE
79+
"t4"."l_partkey" = "t2"."ps_partkey"
80+
AND "t4"."l_suppkey" = "t2"."ps_suppkey"
81+
AND "t4"."l_shipdate" >= MAKE_DATE(1994, 1, 1)
82+
AND "t4"."l_shipdate" < MAKE_DATE(1995, 1, 1)
83+
) AS "t8"
84+
) * CAST(0.5 AS DOUBLE)
85+
)
86+
) AS "t11"
87+
)
88+
) AS "t13"
6789
ORDER BY
68-
"t9"."s_name" ASC
90+
"t13"."s_name" ASC

0 commit comments

Comments
 (0)