Skip to content

Commit fd35b66

Browse files
kszucscpcloud
andauthored
refactor(api): restrict arbitrary input nesting (ibis-project#8917)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
1 parent 73dd685 commit fd35b66

8 files changed

Lines changed: 204 additions & 99 deletions

File tree

ibis/backends/tests/test_vectorized_udf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,7 @@ def test_elementwise_udf_named_destruct(udf_alltypes):
592592
add_one_struct_udf = create_add_one_struct_udf(
593593
result_formatter=lambda v1, v2: (v1, v2)
594594
)
595-
msg = "Duplicate column name 'new_struct' in result set"
596-
with pytest.raises(com.IntegrityError, match=msg):
595+
with pytest.raises(com.InputTypeError, match="Unable to infer datatype"):
597596
udf_alltypes.mutate(
598597
new_struct=add_one_struct_udf(udf_alltypes["double_col"]).destructure()
599598
)

ibis/expr/builders.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,15 @@ def order_by(self, expr) -> Self:
226226
return self.copy(orderings=self.orderings + util.promote_tuple(expr))
227227

228228
def bind(self, table):
229-
from ibis.expr.types.relations import bind
230-
231229
if table is None:
232230
if self._table is None:
233231
raise IbisInputError("Cannot bind window frame without a table")
234232
else:
235233
table = self._table.to_expr()
236234

237-
grouping = bind(table, self.groupings)
238-
orderings = bind(table, self.orderings)
239-
return self.copy(groupings=grouping, orderings=orderings)
235+
return self.copy(
236+
groupings=table.bind(self.groupings), orderings=table.bind(self.orderings)
237+
)
240238

241239

242240
class LegacyWindowBuilder(WindowBuilder):

ibis/expr/rewrites.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ def backtrack(cls, value):
127127
yield value, distance
128128
value = value.rel.values.get(value.name)
129129
distance += 1
130-
if value is not None and not value.find(ops.Impure, filter=ops.Value):
130+
if (
131+
value is not None
132+
and value.relations
133+
and not value.find(ops.Impure, filter=ops.Value)
134+
):
131135
yield value, distance
132136

133137
def dereference(self, value):

ibis/expr/tests/test_newrels.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import ibis.expr.datatypes as dt
1010
import ibis.expr.operations as ops
1111
import ibis.expr.types as ir
12+
import ibis.selectors as s
1213
from ibis import _
1314
from ibis.common.annotations import ValidationError
1415
from ibis.common.exceptions import IbisInputError, IntegrityError
@@ -499,7 +500,7 @@ def test_subsequent_filter():
499500
assert f2.op() == expected
500501

501502

502-
def test_project_dereferences_literal_expressions():
503+
def test_project_doesnt_dereference_literal_expressions():
503504
one = ibis.literal(1)
504505
two = ibis.literal(2)
505506
four = (one + one) * two
@@ -516,7 +517,7 @@ def test_project_dereferences_literal_expressions():
516517
)
517518

518519
t2 = t1.select(four)
519-
assert t2.op() == Project(parent=t1, values={four.get_name(): t1.four})
520+
assert t2.op() == Project(parent=t1, values={four.get_name(): four})
520521

521522

522523
def test_project_before_and_after_filter():
@@ -864,6 +865,22 @@ def test_join_predicate_dereferencing_using_tuple_syntax():
864865
assert j2.op() == expected
865866

866867

868+
def test_join_with_selector_predicate():
869+
t1 = ibis.table(name="t1", schema={"a": "string", "b": "string"})
870+
t2 = ibis.table(name="t2", schema={"c": "string", "d": "string"})
871+
872+
joined = t1.join(t2, s.of_type("string"))
873+
with join_tables(joined) as (r1, r2):
874+
expected = JoinChain(
875+
first=r1,
876+
rest=[
877+
JoinLink("inner", r2, [r1.a == r2.c, r1.b == r2.d]),
878+
],
879+
values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d},
880+
)
881+
assert joined.op() == expected
882+
883+
867884
def test_join_rhs_dereferencing():
868885
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
869886
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})

ibis/expr/types/groupby.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
from ibis.common.grounds import Concrete
2929
from ibis.common.typing import VarTuple # noqa: TCH001
3030
from ibis.expr.rewrites import rewrite_window_input
31-
from ibis.expr.types.relations import bind
3231

3332
if TYPE_CHECKING:
34-
from collections.abc import Iterable, Sequence
33+
from collections.abc import Sequence
3534

3635

3736
@public
@@ -65,13 +64,14 @@ def __getattr__(self, attr):
6564

6665
def aggregate(self, *metrics, **kwds) -> ir.Table:
6766
"""Compute aggregates over a group by."""
67+
metrics = self.table.to_expr().bind(*metrics, **kwds)
6868
return self.table.to_expr().aggregate(
69-
metrics, by=self.groupings, having=self.havings, **kwds
69+
metrics, by=self.groupings, having=self.havings
7070
)
7171

7272
agg = aggregate
7373

74-
def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
74+
def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
7575
"""Add a post-aggregation result filter `expr`.
7676
7777
::: {.callout-warning}
@@ -80,19 +80,19 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
8080
8181
Parameters
8282
----------
83-
expr
84-
An expression that filters based on an aggregate value.
83+
predicates
84+
Expressions that filters based on an aggregate value.
8585
8686
Returns
8787
-------
8888
GroupedTable
8989
A grouped table expression
9090
"""
9191
table = self.table.to_expr()
92-
havings = tuple(bind(table, expr))
92+
havings = table.bind(*predicates)
9393
return self.copy(havings=self.havings + havings)
9494

95-
def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
95+
def order_by(self, *by: ir.Value) -> GroupedTable:
9696
"""Sort a grouped table expression by `expr`.
9797
9898
Notes
@@ -101,7 +101,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
101101
102102
Parameters
103103
----------
104-
expr
104+
by
105105
Expressions to order the results by
106106
107107
Returns
@@ -110,7 +110,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
110110
A sorted grouped GroupedTable
111111
"""
112112
table = self.table.to_expr()
113-
orderings = tuple(bind(table, expr))
113+
orderings = table.bind(*by)
114114
return self.copy(orderings=self.orderings + orderings)
115115

116116
def mutate(
@@ -201,7 +201,7 @@ def _selectables(self, *exprs, **kwexprs):
201201
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
202202
"""
203203
table = self.table.to_expr()
204-
values = bind(table, (exprs, kwexprs))
204+
values = table.bind(*exprs, **kwexprs)
205205
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
206206
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]
207207

ibis/expr/types/joins.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,10 @@ def prepare_predicates(
176176
else:
177177
lk = rk = pred
178178

179-
# bind the predicates to the join chain
180-
(left_value,) = bind(left, lk)
181-
(right_value,) = bind(right, rk)
182-
183-
# dereference the left value to one of the relations in the join chain
184-
left_value = deref_left.dereference(left_value.op())
185-
right_value = deref_right.dereference(right_value.op())
186-
187-
yield comparison(left_value, right_value)
179+
for lhs, rhs in zip(bind(left, lk), bind(right, rk)):
180+
lhs = deref_left.dereference(lhs.op())
181+
rhs = deref_right.dereference(rhs.op())
182+
yield comparison(lhs, rhs)
188183

189184

190185
def finished(method):
@@ -335,8 +330,9 @@ def asof_join(
335330
result = self.left_join(
336331
filtered, predicates=[left_on == right_on] + predicates
337332
)
338-
values = {**self.op().values, **filtered.op().values}
339-
return result.select(values)
333+
values = {**filtered.op().values, **self.op().values}
334+
335+
return result.select(**values)
340336

341337
chain = self.op()
342338
right = right.op()
@@ -383,7 +379,7 @@ def cross_join(
383379
@functools.wraps(Table.select)
384380
def select(self, *args, **kwargs):
385381
chain = self.op()
386-
values = bind(self, (args, kwargs))
382+
values = self.bind(*args, **kwargs)
387383
values = unwrap_aliases(values)
388384

389385
links = [link.table for link in chain.rest if link.how not in ("semi", "anti")]

0 commit comments

Comments
 (0)