Skip to content
This repository was archived by the owner on Mar 21, 2021. It is now read-only.

Commit 5a1dcc1

Browse files
committed
Dedupe computations
1 parent 8ea36ea commit 5a1dcc1

File tree

4 files changed

+46
-41
lines changed

4 files changed

+46
-41
lines changed

qc0/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import click
2+
from datetime import date
23

34

45
@click.command()

qc0/compile.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sqlalchemy.sql.selectable import Selectable, Join, Alias
55
from .base import Struct
66
from .op import (
7+
Field,
78
Op,
89
Rel,
910
RelVoid,
@@ -42,6 +43,7 @@ class From(Struct):
4243
limit: Any = None
4344
order: Any = None
4445
group_by_columns: List[str] = None
46+
compute: Any = None
4547

4648
def __post_init__(self):
4749
if self.group_by_columns is None:
@@ -125,6 +127,9 @@ def to_select(self, value, *compute):
125127

126128

127129
def op_to_sql(op: Op, from_obj):
130+
if from_obj.compute and id(op) in from_obj.compute:
131+
name = from_obj.compute[id(op)].name
132+
return sa.column(name), from_obj
128133
if op.expr is not None:
129134
expr_collect(op.expr)
130135
expr = None
@@ -169,29 +174,30 @@ def RelVoid_to_sql(rel: RelVoid, from_obj):
169174
@rel_to_sql.register
170175
def RelTable_to_sql(rel: RelTable, from_obj):
171176
if rel.compute:
172-
for _, op in rel.compute:
173-
if op.expr is not None:
174-
expr_collect(op.expr)
177+
for field in rel.compute.values():
178+
if field.op.expr is not None:
179+
expr_collect(field.op.expr)
175180

176181
from_obj = From.make(rel.table)
177182

178183
if rel.compute:
179184
at = from_obj.at
180185
columns = []
181-
for name, op in rel.compute:
182-
expr, from_obj = op_to_sql(op, from_obj.replace(at=at))
183-
columns.append(expr.label(name))
186+
for field in rel.compute.values():
187+
expr, from_obj = op_to_sql(field.op, from_obj.replace(at=at))
188+
columns.append(expr.label(field.name))
184189
from_obj = from_obj.replace(at=at)
185190
from_obj = From.make(from_obj.to_select(None, *columns).alias())
191+
from_obj = from_obj.replace(compute=rel.compute)
186192
return from_obj
187193

188194

189195
@rel_to_sql.register
190196
def RelJoin_to_sql(rel: RelJoin, from_obj):
191197
if rel.compute:
192-
for _, op in rel.compute:
193-
if op.expr is not None:
194-
expr_collect(op.expr)
198+
for field in rel.compute.values():
199+
if field.op.expr is not None:
200+
expr_collect(field.op.expr)
195201

196202
if isinstance(rel.rel, RelAroundParent):
197203
table = rel.fk.column.table
@@ -214,9 +220,9 @@ def RelJoin_to_sql(rel: RelJoin, from_obj):
214220
if rel.compute:
215221
at = from_obj.at
216222
columns = []
217-
for name, op in rel.compute:
218-
expr, from_obj = op_to_sql(op, from_obj.replace(at=at))
219-
columns.append(expr.label(name))
223+
for field in rel.compute.values():
224+
expr, from_obj = op_to_sql(field.op, from_obj.replace(at=at))
225+
columns.append(expr.label(field.name))
220226
from_obj = from_obj.replace(at=at)
221227
from_obj = From.make(from_obj.to_select(None, *columns).alias())
222228

@@ -226,9 +232,9 @@ def RelJoin_to_sql(rel: RelJoin, from_obj):
226232
@rel_to_sql.register
227233
def RelRevJoin_to_sql(rel: RelRevJoin, from_obj):
228234
if rel.compute:
229-
for _, op in rel.compute:
230-
if op.expr is not None:
231-
expr_collect(op.expr)
235+
for field in rel.compute.values():
236+
if field.op.expr is not None:
237+
expr_collect(field.op.expr)
232238

233239
if isinstance(rel.rel, RelParent):
234240
table = rel.fk.parent.table.alias()
@@ -254,9 +260,9 @@ def RelRevJoin_to_sql(rel: RelRevJoin, from_obj):
254260
if rel.compute:
255261
at = from_obj.at
256262
columns = []
257-
for name, op in rel.compute:
258-
expr, from_obj = op_to_sql(op, from_obj.replace(at=at))
259-
columns.append(expr.label(name))
263+
for field in rel.compute.values():
264+
expr, from_obj = op_to_sql(field.op, from_obj.replace(at=at))
265+
columns.append(expr.label(field.name))
260266
from_obj = from_obj.replace(at=at)
261267
from_obj = From.make(from_obj.to_select(None, *columns).alias())
262268

@@ -364,7 +370,8 @@ def build_kernel():
364370
return from_obj
365371

366372
result_columns = [from_obj.current.columns[c.name] for c in tuple(columns)]
367-
for name, op in rel.compute:
373+
for field in rel.compute.values():
374+
op = field.op
368375
assert op.sig is not None
369376
columns, kernel = build_kernel()
370377

@@ -394,7 +401,7 @@ def build_kernel():
394401
inner_sel, *((c.name, c.name) for c in columns), outer=True
395402
)
396403
result_columns.append(
397-
sa.func.coalesce(inner_at.c.value, op.sig.unit).label(name)
404+
sa.func.coalesce(inner_at.c.value, op.sig.unit).label(field.name)
398405
)
399406

400407
from_obj = From.make(
@@ -433,13 +440,9 @@ def ExprColumn_to_sql(op: ExprColumn, from_obj):
433440

434441
@expr_to_sql.register
435442
def ExprCompute_to_sql(expr: ExprCompute, from_obj):
436-
found = None
437-
for name, op in expr.rel.compute:
438-
if op == expr.op:
439-
found = name
440-
break
441-
assert found
442-
return sa.column(found, _selectable=from_obj.at), from_obj
443+
field = expr.rel.compute.get(id(expr.op))
444+
assert field
445+
return sa.column(field.name, _selectable=from_obj.at), from_obj
443446

444447

445448
@expr_to_sql.register
@@ -498,9 +501,12 @@ def ExprColumn_collect(expr: ExprColumn):
498501

499502
@expr_collect.register
500503
def ExprAggregate_collect(expr: ExprCompute):
501-
idx = len(expr.rel.compute)
502-
name = f"compute_{idx}"
503-
expr.rel.compute.append((name, expr.op))
504+
key = id(expr.op)
505+
field = expr.rel.compute.get(key)
506+
if field is None:
507+
idx = len(expr.rel.compute)
508+
name = f"compute_{idx}"
509+
expr.rel.compute[id(expr.op)] = Field(name=name, op=expr.op)
504510

505511

506512
@expr_collect.register

qc0/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class RelVoid(Rel):
6969

7070

7171
class RelWithCompute(Rel):
72-
compute: List[Tuple[str, Op]]
72+
compute: Dict[int, Tuple[Field]]
7373

7474

7575
class RelTable(RelWithCompute):

qc0/plan.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def Nav_to_op(syn: Nav, parent: Op):
181181

182182
if isinstance(parent.scope, UnivScope):
183183
table = parent.scope.tables[syn.name]
184-
rel = RelTable(table=table, compute=[])
184+
rel = RelTable(table=table, compute={})
185185
scope = TableScope(rel=rel, table=table)
186186
return Op(
187187
rel=rel,
@@ -206,7 +206,7 @@ def Nav_to_op(syn: Nav, parent: Op):
206206
fk = parent.scope.foreign_keys.get(syn.name)
207207
if fk:
208208
assert parent.expr is None, parent.expr
209-
rel = RelJoin(rel=parent.rel, fk=fk, compute=[])
209+
rel = RelJoin(rel=parent.rel, fk=fk, compute={})
210210
scope = TableScope(rel=rel, table=fk.column.table)
211211
return parent.grow_rel(
212212
rel=rel,
@@ -217,7 +217,7 @@ def Nav_to_op(syn: Nav, parent: Op):
217217
fk = parent.scope.rev_foreign_keys.get(syn.name)
218218
if fk:
219219
assert parent.expr is None, parent.expr
220-
rel = RelRevJoin(rel=parent.rel, fk=fk, compute=[])
220+
rel = RelRevJoin(rel=parent.rel, fk=fk, compute={})
221221
scope = TableScope(rel=rel, table=fk.parent.table)
222222
return parent.grow_rel(
223223
rel=rel,
@@ -233,11 +233,7 @@ def Nav_to_op(syn: Nav, parent: Op):
233233
elif isinstance(parent.scope, RecordScope):
234234
if syn.name in parent.scope.fields:
235235
op_field = parent.scope.op_fields[syn.name]
236-
if (
237-
parent.card == Cardinality.SEQ
238-
and op_field.op.sig is not None
239-
and op_field.op.sig != JsonAggSig
240-
):
236+
if op_field.op.sig and op_field.op.sig != JsonAggSig:
241237
scope = parent.scope
242238
while isinstance(scope, RecordScope):
243239
scope = scope.parent
@@ -346,7 +342,9 @@ def Select_to_op(syn: Select, parent: Op):
346342
if field_op.card == Cardinality.SEQ:
347343
field_op = field_op.aggregate(JsonAggSig)
348344
fields[name] = Field(op=field_op, name=name)
349-
scope = RecordScope(parent=parent.scope, fields=syn.fields, op_fields=fields)
345+
scope = RecordScope(
346+
parent=parent.scope, fields=syn.fields, op_fields=fields
347+
)
350348
return parent.replace(scope=scope)
351349

352350

@@ -416,7 +414,7 @@ def Apply_to_op(syn: Apply, parent: Op):
416414
op = op.grow_expr(ExprIdentity(table=op.scope.table))
417415
fields[name] = Field(op=op, name=name)
418416

419-
rel = RelGroup(rel=parent.rel, fields=fields, compute=[])
417+
rel = RelGroup(rel=parent.rel, fields=fields, compute={})
420418
scope = GroupScope(scope=parent.scope, fields=syn.args, rel=rel)
421419
card = Cardinality.SEQ if fields else Cardinality.ONE
422420
return parent.grow_rel(rel=rel, syn=syn, card=card, scope=scope)

0 commit comments

Comments
 (0)