Skip to content
Open

Typing #1792

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Standardize various method overrides
  • Loading branch information
Armavica committed Dec 17, 2025
commit dee236c59e81a62d4ceaf516034285fb2db37bbb
94 changes: 47 additions & 47 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,72 +521,72 @@ def fast_destroy(self, fgraph, app, reason):
# assert len(v) <= 1
# assert len(d) <= 1

def on_import(self, fgraph, app, reason):
def on_import(self, fgraph, node, reason):
"""
Add Apply instance to set which must be computed.

"""
if app in self.debug_all_apps:
if node in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app)
self.debug_all_apps.add(node)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)

# If it's a destructive op, add it to our watch list
dmap = app.op.destroy_map
vmap = app.op.view_map
dmap = node.op.destroy_map
vmap = node.op.view_map
if dmap:
self.destroyers.add(app)
self.destroyers.add(node)
if self.algo == "fast":
self.fast_destroy(fgraph, app, reason)
self.fast_destroy(fgraph, node, reason)

# add this symbol to the forward and backward maps
for o_idx, i_idx_list in vmap.items():
if len(i_idx_list) > 1:
raise NotImplementedError(
"destroying this output invalidates multiple inputs", (app.op)
"destroying this output invalidates multiple inputs", (node.op)
)
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
o = node.outputs[o_idx]
i = node.inputs[i_idx_list[0]]
self.view_i[o] = i
self.view_o.setdefault(i, OrderedSet()).add(o)

# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app, 0)
self.clients[input][app] += 1
for i, input in enumerate(node.inputs):
self.clients.setdefault(input, {}).setdefault(node, 0)
self.clients[input][node] += 1

for i, output in enumerate(app.outputs):
for i, output in enumerate(node.outputs):
self.clients.setdefault(output, {})

self.stale_droot = True

def on_prune(self, fgraph, app, reason):
def on_prune(self, fgraph, node, reason):
"""
Remove Apply instance from set which must be computed.

"""
if app not in self.debug_all_apps:
if node not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
self.debug_all_apps.remove(node)

# UPDATE self.clients
for input in set(app.inputs):
del self.clients[input][app]
for input in set(node.inputs):
del self.clients[input][node]

if app.op.destroy_map:
self.destroyers.remove(app)
if node.op.destroy_map:
self.destroyers.remove(node)

# Note: leaving empty client dictionaries in the struct.
# Why? It's a pain to remove them. I think they aren't doing any harm, they will be
# deleted on_detach().

# UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in app.op.view_map.items():
for o_idx, i_idx_list in node.op.view_map.items():
if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs
raise NotImplementedError()
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
o = node.outputs[o_idx]
i = node.inputs[i_idx_list[0]]

del self.view_i[o]

Expand All @@ -595,53 +595,53 @@ def on_prune(self, fgraph, app, reason):
del self.view_o[i]

self.stale_droot = True
if app in self.fail_validate:
del self.fail_validate[app]
if node in self.fail_validate:
del self.fail_validate[node]

def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
"""
app.inputs[i] changed from old_r to new_r.
node.inputs[i] changed from var to new_var.

"""
if isinstance(app.op, Output):
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
if isinstance(node.op, Output):
# node == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
pass
else:
if app not in self.debug_all_apps:
if node not in self.debug_all_apps:
raise ProtocolError("change without import")

# UPDATE self.clients
self.clients[old_r][app] -= 1
if self.clients[old_r][app] == 0:
del self.clients[old_r][app]
self.clients[var][node] -= 1
if self.clients[var][node] == 0:
del self.clients[var][node]

self.clients.setdefault(new_r, {}).setdefault(app, 0)
self.clients[new_r][app] += 1
self.clients.setdefault(new_var, {}).setdefault(node, 0)
self.clients[new_var][node] += 1

# UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in app.op.view_map.items():
for o_idx, i_idx_list in node.op.view_map.items():
if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs
raise NotImplementedError()
i_idx = i_idx_list[0]
output = app.outputs[o_idx]
output = node.outputs[o_idx]
if i_idx == i:
if app.inputs[i_idx] is not new_r:
if node.inputs[i_idx] is not new_var:
raise ProtocolError("wrong new_r on change")

self.view_i[output] = new_r
self.view_i[output] = new_var

self.view_o[old_r].remove(output)
if not self.view_o[old_r]:
del self.view_o[old_r]
self.view_o[var].remove(output)
if not self.view_o[var]:
del self.view_o[var]

self.view_o.setdefault(new_r, OrderedSet()).add(output)
self.view_o.setdefault(new_var, OrderedSet()).add(output)

if self.algo == "fast":
if app in self.fail_validate:
del self.fail_validate[app]
self.fast_destroy(fgraph, app, reason)
if node in self.fail_validate:
del self.fail_validate[node]
self.fast_destroy(fgraph, node, reason)
self.stale_droot = True

def validate(self, fgraph):
Expand Down
22 changes: 11 additions & 11 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,11 @@ def on_detach(self, fgraph):
del fgraph.revert
del self.history[fgraph]

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
if self.history[fgraph] is None:
return
h = self.history[fgraph]
h.append(LambdaExtract(fgraph, node, i, r, reason))
h.append(LambdaExtract(fgraph, node, i, var, reason))

def revert(self, fgraph, checkpoint):
"""
Expand Down Expand Up @@ -544,9 +544,9 @@ def on_attach(self, fgraph):
raise ValueError("Full History already attached to another fgraph")
self.fg = fgraph

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
self.bw.append(LambdaExtract(fgraph, node, i, var, reason))
self.fw.append(LambdaExtract(fgraph, node, i, new_var, reason))
self.pointer += 1
if self.callback:
self.callback()
Expand Down Expand Up @@ -832,15 +832,15 @@ class PreserveVariableAttributes(Feature):
This preserve some variables attributes and tag during optimization.
"""

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
# Don't change the name of constants
if r.owner and r.name is not None and new_r.name is None:
new_r.name = r.name
if var.owner and var.name is not None and new_var.name is None:
new_var.name = var.name
if (
getattr(r.tag, "nan_guard_mode_check", False)
and getattr(new_r.tag, "nan_guard_mode_check", False) is False
getattr(var.tag, "nan_guard_mode_check", False)
and getattr(new_var.tag, "nan_guard_mode_check", False) is False
):
new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check
new_var.tag.nan_guard_mode_check = var.tag.nan_guard_mode_check


class NoOutputFromInplace(Feature):
Expand Down
18 changes: 9 additions & 9 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,15 +550,15 @@ def on_attach(self, fgraph):
def clone(self):
return type(self)()

def on_change_input(self, fgraph, node, i, r, new_r, reason):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
if node in self.nodes_seen:
# If inputs to a node change, it's not guaranteed that the node is
# distinct from the other nodes in `self.nodes_seen`.
self.nodes_seen.discard(node)
self.process_node(fgraph, node)

if isinstance(new_r, AtomicVariable):
self.process_atomic(fgraph, new_r)
if isinstance(new_var, AtomicVariable):
self.process_atomic(fgraph, new_var)

def on_import(self, fgraph, node, reason):
for c in node.inputs:
Expand Down Expand Up @@ -973,7 +973,7 @@ def __init__(self, fn, tracks=None, requirements=()):
)
self.requirements = requirements

def transform(self, fgraph, node, enforce_tracks: bool = True):
def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs):
if enforce_tracks and self._tracks:
node_op = node.op
if not (
Expand Down Expand Up @@ -1230,7 +1230,7 @@ def tracks(self):
t.extend(at)
return t

def transform(self, fgraph, node, enforce_tracks=False):
def transform(self, fgraph, node, enforce_tracks=False, *args, **kwargs):
if len(self.rewrites) == 0:
return

Expand Down Expand Up @@ -1385,7 +1385,7 @@ def __init__(self, op1, op2, transfer_tags=True):
def tracks(self):
return [self.op1]

def transform(self, fgraph, node, enforce_tracks=True):
def transform(self, fgraph, node, enforce_tracks=True, *args, **kwargs):
if enforce_tracks and (node.op != self.op1):
return False
repl = self.op2.make_node(*node.inputs)
Expand Down Expand Up @@ -1713,9 +1713,9 @@ def on_prune(self, fgraph, node, reason):
if self.pruner:
self.pruner(node)

def on_change_input(self, fgraph, node, i, r, new_r, reason):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
if self.chin:
self.chin(node, i, r, new_r, reason)
self.chin(node, i, var, new_var, reason)

def on_detach(self, fgraph):
# To allow pickling this object
Expand Down Expand Up @@ -2160,7 +2160,7 @@ def on_import(self, fgraph, node, reason):
self.nb_imported += 1
self.changed = True

def on_change_input(self, fgraph, node, i, r, new_r, reason):
def on_change_input(self, fgraph, node, i, var, new_var, reason=None):
self.changed = True

def reset(self):
Expand Down
8 changes: 4 additions & 4 deletions pytensor/graph/rewriting/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,10 @@ def __init__(
self.__position__ = {}
self.failure_callback = failure_callback

def register(self, name, obj, *tags, **kwargs):
def register(self, name, rewriter, *tags, **kwargs):
position = kwargs.pop("position", "last")

super().register(name, obj, *tags, **kwargs)
super().register(name, rewriter, *tags, **kwargs)

if position == "last":
if len(self.__position__) == 0:
Expand Down Expand Up @@ -497,8 +497,8 @@ def __init__(
self.node_rewriter = node_rewriter
self.__name__: str = ""

def register(self, name, obj, *tags, position="last", **kwargs):
super().register(name, obj, *tags, position=position, **kwargs)
def register(self, name, rewriter, *tags, position="last", **kwargs):
super().register(name, rewriter, *tags, position=position, **kwargs)

def query(self, *tags, **kwtags):
rewrites = list(super().query(*tags, **kwtags))
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/kanren.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def results_filter(
self.node_filter = node_filter
super().__init__()

def transform(self, fgraph, node, enforce_tracks: bool = True):
def transform(self, fgraph, node, enforce_tracks: bool = True, *args, **kwargs):
if self.node_filter(node) is False:
return False

Expand Down
4 changes: 2 additions & 2 deletions pytensor/misc/ordered_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(self, iterable: Iterable | None = None) -> None:
else:
self.values = dict.fromkeys(iterable)

def __contains__(self, value) -> bool:
return value in self.values
def __contains__(self, x) -> bool:
return x in self.values

def __iter__(self) -> Iterator:
yield from self.values
Expand Down
Loading
Loading