Skip to content

Commit 38e7e14

Browse files
authored
refactor(api): make input value coercion of mutate() identical to select() (ibis-project#8878)
Previously, string literals passed to select() were interpreted as column references, whereas mutate() interpreted them as literals. BREAKING CHANGE: strings (and integers) passed to table.mutate() are now interpreted as column references instead of literals; use `ibis.literal(value)` to pass the value as a literal.
1 parent d7f94e5 commit 38e7e14

11 files changed

Lines changed: 29 additions & 26 deletions

File tree

docs/tutorials/ibis-for-sql-users.qmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ In many other situations, you can use constants without having to use
600600
but the number 5 like so:
601601

602602
```{python}
603-
expr = t3.mutate(number5=5)
603+
expr = t3.mutate(number5=ibis.literal(5))
604604
ibis.to_sql(expr)
605605
```
606606

ibis/backends/dask/tests/test_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def test_mutate_scalar_with_window_after_join(npartitions):
382382

383383
joined = left.outer_join(right, left.ints == right.group)
384384
proj = joined[left, right.value]
385-
expr = proj.mutate(sum=proj.value.sum(), const=1)
385+
expr = proj.mutate(sum=proj.value.sum(), const=ibis.literal(1))
386386
result = expr.execute()
387387
result = result.sort_values(["ints", "value"]).reset_index(drop=True)
388388
expected = (

ibis/backends/pandas/tests/test_join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def test_mutate_after_join():
506506
q_count=joined["q_count"].fillna(0),
507507
p_density=joined.p_density.fillna(1e-10),
508508
q_density=joined.q_density.fillna(1e-10),
509-
features="Order_Priority",
509+
features=ibis.literal("Order_Priority"),
510510
)
511511

512512
expected = pd.DataFrame(

ibis/backends/pandas/tests/test_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def test_mutate_scalar_with_window_after_join():
391391

392392
joined = left.outer_join(right, left.ints == right.group)
393393
proj = joined[left, right.value]
394-
expr = proj.mutate(sum=proj.value.sum(), const=1)
394+
expr = proj.mutate(sum=proj.value.sum(), const=ibis.literal(1))
395395
result = expr.execute()
396396
expected = pd.DataFrame(
397397
{

ibis/backends/tests/sql/test_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def test_mutate_filter_join_no_cross_join(snapshot):
390390
[("person_id", "int64"), ("birth_datetime", "timestamp")],
391391
name="person",
392392
)
393-
mutated = person.mutate(age=400)
393+
mutated = person.mutate(age=ibis.literal(400))
394394
expr = mutated.filter(mutated.age <= 40)[mutated.person_id]
395395

396396
snapshot.assert_match(to_sql(expr), "out.sql")

ibis/backends/tests/test_client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,7 @@ def test_create_table_timestamp(con, temp_table):
12141214
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
12151215
)
12161216
def test_persist_expression_ref_count(backend, con, alltypes):
1217-
non_persisted_table = alltypes.mutate(test_column="calculation")
1217+
non_persisted_table = alltypes.mutate(test_column=ibis.literal("calculation"))
12181218
persisted_table = non_persisted_table.cache()
12191219

12201220
op = non_persisted_table.op()
@@ -1239,7 +1239,9 @@ def test_persist_expression_ref_count(backend, con, alltypes):
12391239
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
12401240
)
12411241
def test_persist_expression(backend, alltypes):
1242-
non_persisted_table = alltypes.mutate(test_column="calculation", other_calc="xyz")
1242+
non_persisted_table = alltypes.mutate(
1243+
test_column=ibis.literal("calculation"), other_calc=ibis.literal("xyz")
1244+
)
12431245
persisted_table = non_persisted_table.cache()
12441246
backend.assert_frame_equal(
12451247
non_persisted_table.to_pandas(), persisted_table.to_pandas()
@@ -1259,7 +1261,7 @@ def test_persist_expression(backend, alltypes):
12591261
)
12601262
def test_persist_expression_contextmanager(backend, alltypes):
12611263
non_cached_table = alltypes.mutate(
1262-
test_column="calculation", other_column="big calc"
1264+
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc")
12631265
)
12641266
with non_cached_table.cache() as cached_table:
12651267
backend.assert_frame_equal(
@@ -1280,7 +1282,7 @@ def test_persist_expression_contextmanager(backend, alltypes):
12801282
)
12811283
def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
12821284
non_cached_table = alltypes.mutate(
1283-
test_column="calculation", other_column="big calc 2"
1285+
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
12841286
)
12851287
op = non_cached_table.op()
12861288
with non_cached_table.cache() as cached_table:
@@ -1304,7 +1306,7 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
13041306
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
13051307
def test_persist_expression_multiple_refs(backend, con, alltypes):
13061308
non_cached_table = alltypes.mutate(
1307-
test_column="calculation", other_column="big calc 2"
1309+
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
13081310
)
13091311
op = non_cached_table.op()
13101312
with non_cached_table.cache() as cached_table:
@@ -1345,7 +1347,7 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
13451347
)
13461348
def test_persist_expression_repeated_cache(alltypes):
13471349
non_cached_table = alltypes.mutate(
1348-
test_column="calculation", other_column="big calc 2"
1350+
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 2")
13491351
)
13501352
with non_cached_table.cache() as cached_table:
13511353
with cached_table.cache() as nested_cached_table:
@@ -1374,7 +1376,7 @@ def test_persist_expression_repeated_cache(alltypes):
13741376
)
13751377
def test_persist_expression_release(con, alltypes):
13761378
non_cached_table = alltypes.mutate(
1377-
test_column="calculation", other_column="big calc 3"
1379+
test_column=ibis.literal("calculation"), other_column=ibis.literal("big calc 3")
13781380
)
13791381
cached_table = non_cached_table.cache()
13801382
cached_table.release()

ibis/backends/tests/test_generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ def test_uncorrelated_subquery(backend, batting, batting_df):
785785

786786

787787
def test_int_column(alltypes):
788-
expr = alltypes.mutate(x=1).x
788+
expr = alltypes.mutate(x=ibis.literal(1)).x
789789
result = expr.execute()
790790
assert expr.type() == dt.int8
791791
assert result.dtype == np.int8

ibis/backends/tests/test_temporal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,7 @@ def test_interval_add_cast_column(backend, alltypes, df):
12361236
),
12371237
param(
12381238
lambda t: (
1239-
t.mutate(suffix="%d")
1239+
t.mutate(suffix=ibis.literal("%d"))
12401240
.select(formatted=lambda t: t.timestamp_col.strftime("%Y%m" + t.suffix))
12411241
.formatted
12421242
),

ibis/expr/tests/test_newrels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_mutate():
232232
def test_mutate_overwrites_existing_column():
233233
t = ibis.table(dict(a="string", b="string"))
234234

235-
mut = t.mutate(a=42)
235+
mut = t.mutate(a=ibis.literal(42))
236236
assert mut.op() == Project(parent=t, values={"a": ibis.literal(42), "b": t.b})
237237

238238
sel = mut.select("a")

ibis/expr/types/relations.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def f( # noqa: D417
9595

9696
# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
9797
# nested inputs
98-
def bind(table: Table, value: Any, prefer_column=True) -> Iterator[ir.Value]:
98+
def bind(table: Table, value: Any) -> Iterator[ir.Value]:
9999
"""Bind a value to a table expression."""
100-
if prefer_column and type(value) in (str, int):
100+
if type(value) in (str, int):
101101
yield table._get_column(value)
102102
elif isinstance(value, ValueExpr):
103103
yield value
@@ -110,11 +110,11 @@ def bind(table: Table, value: Any, prefer_column=True) -> Iterator[ir.Value]:
110110
yield from value.expand(table)
111111
elif isinstance(value, Mapping):
112112
for k, v in value.items():
113-
for val in bind(table, v, prefer_column=prefer_column):
113+
for val in bind(table, v):
114114
yield val.name(k)
115115
elif util.is_iterable(value):
116116
for v in value:
117-
yield from bind(table, v, prefer_column=prefer_column)
117+
yield from bind(table, v)
118118
elif isinstance(value, ops.Value):
119119
# TODO(kszucs): from certain builders, like ir.GroupedTable we pass
120120
# operation nodes instead of expressions to table methods, it would
@@ -1946,7 +1946,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab
19461946
# string and integer inputs are interpreted as column references, exactly
19471947
# like in select; wrap them in ibis.literal() to pass them as literals
19481948
node = self.op()
1949-
values = bind(self, (exprs, mutations), prefer_column=False)
1949+
values = bind(self, (exprs, mutations))
19501950
values = unwrap_aliases(values)
19511951
# allow overriding of fields, hence the mutation behavior
19521952
values = {**node.fields, **values}
@@ -3359,7 +3359,8 @@ def cache(self) -> Table:
33593359
>>> import ibis
33603360
>>> ibis.options.interactive = True
33613361
>>> t = ibis.examples.penguins.fetch()
3362-
>>> cached_penguins = t.mutate(computation="Heavy Computation").cache()
3362+
>>> heavy_computation = ibis.literal("Heavy Computation")
3363+
>>> cached_penguins = t.mutate(computation=heavy_computation).cache()
33633364
>>> cached_penguins
33643365
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓
33653366
┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃
@@ -3381,7 +3382,7 @@ def cache(self) -> Table:
33813382
33823383
Explicit cache cleanup
33833384
3384-
>>> with t.mutate(computation="Heavy Computation").cache() as cached_penguins:
3385+
>>> with t.mutate(computation=heavy_computation).cache() as cached_penguins:
33853386
... cached_penguins
33863387
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓
33873388
┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃

0 commit comments

Comments
 (0)