Skip to content

Commit feeb8ae

Browse files
authored
refactor(api): treat integer inputs as literals instead of column references (ibis-project#8884)
Removing the int-as-column references in mutate and select. Further discussed in ibis-project#8878. BREAKING CHANGE: Integer inputs to `select` and `mutate` are now always interpreted as literals. Columns can still be accessed by their integer index using square-bracket syntax.
1 parent 0f00101 commit feeb8ae

4 files changed

Lines changed: 45 additions & 60 deletions

File tree

ibis/expr/types/relations.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ibis import util
2424
from ibis.common.deferred import Deferred, Resolver
2525
from ibis.expr.types.core import Expr, _FixedTextJupyterMixin
26-
from ibis.expr.types.generic import ValueExpr, literal
26+
from ibis.expr.types.generic import Value, literal
2727
from ibis.expr.types.pretty import to_rich
2828
from ibis.selectors import Selector
2929
from ibis.util import deprecated
@@ -97,15 +97,23 @@ def f( # noqa: D417
9797

9898
# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
9999
# nested inputs
100-
def bind(table: Table, value: Any) -> Iterator[ir.Value]:
100+
def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
101101
"""Bind a value to a table expression."""
102-
if type(value) in (str, int):
103-
yield table._get_column(value)
104-
elif isinstance(value, ValueExpr):
102+
if isinstance(value, str):
103+
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
104+
yield ops.Field(table, value).to_expr()
105+
elif isinstance(value, bool):
106+
yield literal(value)
107+
elif int_as_column and isinstance(value, int):
108+
name = table.columns[value]
109+
yield ops.Field(table, name).to_expr()
110+
elif isinstance(value, ops.Value):
111+
yield value.to_expr()
112+
elif isinstance(value, Value):
105113
yield value
106114
elif isinstance(value, Table):
107115
for name in value.columns:
108-
yield value._get_column(name)
116+
yield ops.Field(table, name).to_expr()
109117
elif isinstance(value, Deferred):
110118
yield value.resolve(table)
111119
elif isinstance(value, Resolver):
@@ -114,17 +122,11 @@ def bind(table: Table, value: Any) -> Iterator[ir.Value]:
114122
yield from value.expand(table)
115123
elif isinstance(value, Mapping):
116124
for k, v in value.items():
117-
for val in bind(table, v):
125+
for val in bind(table, v, int_as_column=int_as_column):
118126
yield val.name(k)
119127
elif util.is_iterable(value):
120128
for v in value:
121-
yield from bind(table, v)
122-
elif isinstance(value, ops.Value):
123-
# TODO(kszucs): from certain builders, like ir.GroupedTable we pass
124-
# operation nodes instead of expressions to table methods, it would
125-
# be better to convert them to expressions before passing them to
126-
# this function
127-
yield value.to_expr()
129+
yield from bind(table, v, int_as_column=int_as_column)
128130
elif callable(value):
129131
yield value(table)
130132
else:
@@ -567,13 +569,6 @@ def preview(
567569
console_width=console_width,
568570
)
569571

570-
# TODO(kszucs): expose this method in the public API
571-
def _get_column(self, name: str | int) -> ir.Column:
572-
"""Get a column from the table."""
573-
if isinstance(name, int):
574-
name = self.schema().name_at_position(name)
575-
return ops.Field(self, name).to_expr()
576-
577572
def __getitem__(self, what):
578573
"""Select items from a table expression.
579574
@@ -820,22 +815,18 @@ def __getitem__(self, what):
820815
"""
821816
from ibis.expr.types.logical import BooleanValue
822817

823-
if isinstance(what, (str, int)):
824-
return self._get_column(what)
825-
elif isinstance(what, slice):
818+
if isinstance(what, slice):
826819
limit, offset = util.slice_to_limit_offset(what, self.count())
827820
return self.limit(limit, offset=offset)
828-
elif isinstance(what, (list, tuple, Table)):
829-
# Projection case
830-
return self.select(what)
831-
832-
items = tuple(bind(self, what))
833-
if util.all_of(items, BooleanValue):
834-
# TODO(kszucs): this branch should be removed, .filter should be
835-
# used instead
836-
return self.filter(items)
821+
822+
values = tuple(bind(self, what, int_as_column=True))
823+
if isinstance(what, (str, int)):
824+
assert len(values) == 1
825+
return values[0]
826+
elif util.all_of(values, BooleanValue):
827+
return self.filter(values)
837828
else:
838-
return self.select(items)
829+
return self.select(values)
839830

840831
def __len__(self):
841832
raise com.ExpressionError("Use .count() instead")
@@ -878,7 +869,7 @@ def __getattr__(self, key: str) -> ir.Column:
878869
└───────────┘
879870
"""
880871
try:
881-
return self._get_column(key)
872+
return ops.Field(self, key).to_expr()
882873
except com.IbisTypeError:
883874
pass
884875

@@ -2073,7 +2064,7 @@ def select(
20732064
20742065
Projection by zero-indexed column position
20752066
2076-
>>> t.select(0, 4).head()
2067+
>>> t.select(t[0], t[4]).head()
20772068
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
20782069
┃ species ┃ flipper_length_mm ┃
20792070
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
@@ -4409,10 +4400,10 @@ def relocate(
44094400
where = 0
44104401

44114402
# all columns that should come BEFORE the matched selectors
4412-
front = [left for left in range(where) if left not in sels]
4403+
front = [self[left] for left in range(where) if left not in sels]
44134404

44144405
# all columns that should come AFTER the matched selectors
4415-
back = [right for right in range(where, ncols) if right not in sels]
4406+
back = [self[right] for right in range(where, ncols) if right not in sels]
44164407

44174408
# selected columns
44184409
middle = [self[i].name(name) for i, name in sels.items()]

ibis/tests/expr/test_selectors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_if_any(penguins):
342342

343343

344344
def test_negate_range(penguins):
345-
assert penguins.select(~s.r[3:]).equals(penguins.select(0, 1, 2))
345+
assert penguins.select(~s.r[3:]).equals(penguins[[0, 1, 2]])
346346

347347

348348
def test_string_range_start(penguins):
@@ -378,16 +378,15 @@ def test_all(penguins):
378378
@pytest.mark.parametrize(
379379
("seq", "expected"),
380380
[
381-
param([0, 1, 2], (0, 1, 2), id="int_tuple"),
382381
param(~s.r[[3, 4, 5]], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_list"),
383382
param(~s.r[3, 4, 5], sorted(set(range(8)) - {3, 4, 5}), id="neg_int_tuple"),
384383
param(s.r["island", "year"], ("island", "year"), id="string_tuple"),
385384
param(s.r[["island", "year"]], ("island", "year"), id="string_list"),
386-
param(iter(["island", 4, "year"]), ("island", 4, "year"), id="mixed_iterable"),
385+
param(iter(["island", "year"]), ("island", "year"), id="mixed_iterable"),
387386
],
388387
)
389388
def test_sequence(penguins, seq, expected):
390-
assert penguins.select(seq).equals(penguins.select(*expected))
389+
assert penguins.select(seq).equals(penguins[expected])
391390

392391

393392
def test_names_callable(penguins):

ibis/tests/expr/test_table.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
from ibis import _
2020
from ibis.common.annotations import ValidationError
2121
from ibis.common.deferred import Deferred
22-
from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError
22+
from ibis.common.exceptions import (
23+
ExpressionError,
24+
IbisTypeError,
25+
IntegrityError,
26+
RelationError,
27+
)
2328
from ibis.expr import api
2429
from ibis.expr.rewrites import simplify
2530
from ibis.expr.tests.test_newrels import join_tables
@@ -230,7 +235,7 @@ def test_projection_with_star_expr(table):
230235

231236
# cannot pass an invalid table expression
232237
t2 = t.aggregate([t["a"].sum().name("sum(a)")], by=["g"])
233-
with pytest.raises(IntegrityError):
238+
with pytest.raises(IbisTypeError):
234239
t[[t2]]
235240
# TODO: there may be some ways this can be invalid
236241

@@ -581,10 +586,8 @@ def test_order_by_scalar(table, key, expected):
581586
("key", "exc_type"),
582587
[
583588
("bogus", com.IbisTypeError),
584-
# (("bogus", False), com.IbisTypeError),
589+
(("bogus", False), com.IbisTypeError),
585590
(ibis.desc("bogus"), com.IbisTypeError),
586-
(1000, IndexError),
587-
# ((1000, False), IndexError),
588591
(_.bogus, AttributeError),
589592
(_.bogus.desc(), AttributeError),
590593
],
@@ -746,7 +749,7 @@ def test_aggregate_keywords(table):
746749
def test_select_on_literals(table):
747750
# literal ints and strings are column indices, everything else is a value
748751
expr1 = table.select(col1=True, col2=1, col3="a")
749-
expr2 = table.select(col1=ibis.literal(True), col2=table.b, col3=table.a)
752+
expr2 = table.select(col1=ibis.literal(True), col2=ibis.literal(1), col3=table.a)
750753
assert expr1.equals(expr2)
751754

752755

@@ -1280,7 +1283,7 @@ def test_inner_join_overlapping_column_names():
12801283
lambda t1, t2: [(t1.foo_id, t2.foo_id)],
12811284
lambda t1, t2: [(_.foo_id, _.foo_id)],
12821285
lambda t1, t2: [(t1.foo_id, _.foo_id)],
1283-
lambda t1, t2: [(2, 0)], # foo_id is 2nd in t1, 0th in t2
1286+
lambda t1, t2: [(t1[2], t2[0])], # foo_id is 2nd in t1, 0th in t2
12841287
lambda t1, t2: [(lambda t: t.foo_id, lambda t: t.foo_id)],
12851288
],
12861289
)

ibis/tests/expr/test_value_exprs.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -882,18 +882,10 @@ def test_bitwise_exprs(fn, expected_op):
882882
([1, 0], ["bar", "foo"]),
883883
],
884884
)
885-
@pytest.mark.parametrize(
886-
"expr_func",
887-
[
888-
lambda t, args: t[args],
889-
lambda t, args: t.order_by(args),
890-
lambda t, args: t.group_by(args).aggregate(bar_avg=t.bar.mean()),
891-
],
892-
)
893-
def test_table_operations_with_integer_column(position, names, expr_func):
885+
def test_table_operations_with_integer_column(position, names):
894886
t = ibis.table([("foo", "string"), ("bar", "double")])
895-
result = expr_func(t, position)
896-
expected = expr_func(t, names)
887+
result = t[position]
888+
expected = t[names]
897889
assert result.equals(expected)
898890

899891

0 commit comments

Comments
 (0)