Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions deepmd/pt/utils/compile_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor:
return torch.cat([t, *([last] * repeats)], dim=dim)


def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
def strip_saved_tensor_detach(
gm: torch.fx.GraphModule, *, remove_all: bool = False
) -> None:
"""Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors.

When ``make_fx`` decomposes ``autograd.grad(..., create_graph=True)``,
Expand All @@ -171,7 +173,8 @@ def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
model parameters, causing incorrect parameter updates during force-loss
training.

User-explicit ``.detach()`` calls are preserved. The two categories are
With ``remove_all=False`` (default), user-explicit ``.detach()`` calls are
preserved. The make_fx-inserted and user-explicit detaches are
distinguished by graph topology alone — no hard-coded op names — using
three rules:

Expand All @@ -180,7 +183,14 @@ def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
* *Chain head*: *all* users are detach nodes.

Any detach that does **not** match these rules is treated as user-explicit
and left untouched.
and left untouched. This is the right behaviour for the SeZM model
inference compile path, which contains legitimate user ``.detach()`` calls.

With ``remove_all=True``, *every* detach node is removed unconditionally.
The pt_expt training trace is invoked with already-detached, grad-enabled
inputs and opens with ``coord.detach().requires_grad_(True)``; that
boundary detach must also go or the force-loss gradient path is severed, so
the training path passes ``remove_all=True``.
"""
_DETACH = torch.ops.aten.detach.default

Expand All @@ -197,6 +207,9 @@ def _is_detach(n: torch.fx.Node) -> bool:
for node in gm.graph.nodes:
if not _is_detach(node):
continue
if remove_all:
to_remove.append(node)
continue
input_node = node.args[0]
users = list(node.users.keys())
is_chain_inner = _is_detach(input_node)
Expand Down
Loading
Loading