diff --git a/deepmd/pt/utils/compile_compat.py b/deepmd/pt/utils/compile_compat.py index 76c6c0c046..3cdfd54da6 100644 --- a/deepmd/pt/utils/compile_compat.py +++ b/deepmd/pt/utils/compile_compat.py @@ -161,7 +161,9 @@ def trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor: return torch.cat([t, *([last] * repeats)], dim=dim) -def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: +def strip_saved_tensor_detach( + gm: torch.fx.GraphModule, *, remove_all: bool = False +) -> None: """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. When ``make_fx`` decomposes ``autograd.grad(..., create_graph=True)``, @@ -171,7 +173,8 @@ def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: model parameters, causing incorrect parameter updates during force-loss training. - User-explicit ``.detach()`` calls are preserved. The two categories are + With ``remove_all=False`` (default), user-explicit ``.detach()`` calls are + preserved. The make_fx-inserted and user-explicit detaches are distinguished by graph topology alone — no hard-coded op names — using three rules: @@ -180,7 +183,14 @@ def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: * *Chain head*: *all* users are detach nodes. Any detach that does **not** match these rules is treated as user-explicit - and left untouched. + and left untouched. This is the right behaviour for the SeZM model + inference compile path, which contains legitimate user ``.detach()`` calls. + + With ``remove_all=True``, *every* detach node is removed unconditionally. + The pt_expt training trace is invoked with already-detached, grad-enabled + inputs and opens with ``coord.detach().requires_grad_(True)``; that + boundary detach must also go or the force-loss gradient path is severed, so + the training path passes ``remove_all=True``. 
""" _DETACH = torch.ops.aten.detach.default @@ -197,6 +207,9 @@ def _is_detach(n: torch.fx.Node) -> bool: for node in gm.graph.nodes: if not _is_detach(node): continue + if remove_all: + to_remove.append(node) + continue input_node = node.args[0] users = list(node.users.keys()) is_chain_inner = _is_detach(input_node) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 70a880094b..202c5d10de 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -38,6 +38,12 @@ format_training_message, format_training_message_per_task, ) +from deepmd.pt.utils.compile_compat import next_safe_prime as _next_safe_prime +from deepmd.pt.utils.compile_compat import rebuild_graph_module as _rebuild_graph_module +from deepmd.pt.utils.compile_compat import ( + strip_saved_tensor_detach as _strip_saved_tensor_detach, +) +from deepmd.pt.utils.compile_compat import trace_pad_dim as _trace_pad_dim from deepmd.pt_expt.loss import ( DOSLoss, EnergyLoss, @@ -265,45 +271,6 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: # --------------------------------------------------------------------------- -def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None: - """Remove ``aten.detach.default`` nodes from an FX graph in-place. - - ``make_fx`` inserts these nodes when recording saved tensors from the - autograd backward pass (``autograd.grad`` with ``create_graph=True``). - The detach breaks the gradient connection between saved activations and - model parameters, causing incorrect second-order derivatives — e.g. - bias gradients become zero for force-loss training. - - Removing these nodes restores the gradient path so that higher-order - derivatives flow correctly through the decomposed backward ops. 
- """ - graph = gm.graph - for node in list(graph.nodes): - if node.op == "call_function" and node.target == torch.ops.aten.detach.default: - input_node = node.args[0] - node.replace_all_uses_with(input_node) - graph.erase_node(node) - graph.lint() - gm.recompile() - - -def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Copy all nodes into a fresh ``torch.fx.Graph``. - - After ``Graph.erase_node()`` the C-level prev/next pointers on - neighbouring ``Node`` objects may become stale. When ``torch.compile`` - (dynamo) later re-traces the graph it walks these pointers, which can - cause segfaults. Rebuilding into a new graph eliminates stale pointers. - """ - old_graph = gm.graph - new_graph = torch.fx.Graph() - val_map: dict[torch.fx.Node, torch.fx.Node] = {} - for node in old_graph.nodes: - val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) - new_graph.lint() - return torch.fx.GraphModule(gm, new_graph) - - def _trace_and_compile( model: torch.nn.Module, ext_coord: torch.Tensor, @@ -417,35 +384,69 @@ def fn( if _fitting is not None: _fitting._buffers[name] = orig - # Pick a trace-time nframes that's unlikely to collide with any other - # tensor dim in the graph. The symbolic tracer merges symbols that - # are numerically equal at trace time, which bakes nframes into the - # compiled graph whenever it matches e.g. numb_fparam, numb_aparam, - # ntypes, axis_neuron, or neuron sizes (8, 16, 32, ...). Using a - # prime value of 7 avoids the common small-dim collisions while still - # being cheap to trace. - _TRACE_NFRAMES = 7 - cur_nframes = ext_coord.shape[0] - if cur_nframes != _TRACE_NFRAMES: - - def _expand(t: torch.Tensor | None) -> torch.Tensor | None: - if t is None: - return None - # Repeat rows so total nframes == _TRACE_NFRAMES. Use index - # gather (mod) so we don't require divisibility. 
- idx = ( - torch.arange(_TRACE_NFRAMES, dtype=torch.long, device=t.device) - % cur_nframes - ) - return t.index_select(0, idx) - - ext_coord = _expand(ext_coord) - ext_atype = _expand(ext_atype) - nlist = _expand(nlist) - mapping = _expand(mapping) - fparam = _expand(fparam) - aparam = _expand(aparam) - charge_spin = _expand(charge_spin) + # Pad nf to a safe prime; keep real nloc and nall from the data. + # + # make_fx (tracing_mode="symbolic") unifies dimension symbols that share + # the same concrete value at trace time (duck-shape merging). We take + # one frame ([:1]) to normalise nf, then pad it to a prime so PyTorch + # does not specialise it as the constant 1. nloc and nall come from + # real data, so they are already too + # large to alias with any architecture dim and need no adjustment. + # + # The prime for nf is chosen by enumerating every dimension that appears + # in the model's parameters and buffers, then calling _next_safe_prime to + # find the first prime that doesn't collide with any of them. This + # catches internal dims like g2_dim, axis_neuron, attn_head, etc. + # without requiring a hardcoded list. + _forbidden: set[int] = { + int(_d) + for _src in (model.parameters(), model.buffers()) + for _p in _src + for _d in _p.shape + if _d > 1 + } + # Also add the real nloc and nall so trace_nf never aliases them. + _forbidden.add(int(ext_coord.shape[1])) # nall + _forbidden.add(int(ext_atype.shape[1])) # nall (same dimension as ext_coord; defensive) + _forbidden.add(int(nlist.shape[1])) # nloc + # nsel stays at its real value; add it to forbidden for the same reason.
+ _nsel = int(nlist.shape[2]) + if _nsel > 1: + _forbidden.add(_nsel) + try: + _dim_fp = model.get_dim_fparam() + if _dim_fp > 1: + _forbidden.add(_dim_fp) + except Exception: + pass + try: + _dim_ap = model.get_dim_aparam() + if _dim_ap > 1: + _forbidden.add(_dim_ap) + except Exception: + pass + if charge_spin is not None: + _dim_cs = int(charge_spin.shape[1]) + if _dim_cs > 1: + _forbidden.add(_dim_cs) + for _tbv in task_buf_vals_trace: + for _d in _tbv.shape: + if _d > 1: + _forbidden.add(int(_d)) + + trace_nf = _next_safe_prime(5, _forbidden) + + # Pad nf only; nloc and nall retain their real values (no clamping needed). + ext_coord = _trace_pad_dim(ext_coord[:1], 0, trace_nf) + ext_atype = _trace_pad_dim(ext_atype[:1], 0, trace_nf) + nlist = _trace_pad_dim(nlist[:1], 0, trace_nf) + mapping = _trace_pad_dim(mapping[:1], 0, trace_nf) + if fparam is not None: + fparam = _trace_pad_dim(fparam[:1], 0, trace_nf) + if aparam is not None: + aparam = _trace_pad_dim(aparam[:1], 0, trace_nf) + if charge_spin is not None: + charge_spin = _trace_pad_dim(charge_spin[:1], 0, trace_nf) # Decompose silu_backward into primitive ops (sigmoid + mul + ...) # so that inductor can compile the graph without requiring a @@ -476,8 +477,9 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break # second-order gradient flow (d(force)/d(params) for force training). - # Removing them restores correct higher-order derivatives. - _remove_detach_nodes(traced_lower) + # The training trace is fed already-detached, grad-enabled inputs, so + # every detach is removed unconditionally to restore the gradient path. + _strip_saved_tensor_detach(traced_lower, remove_all=True) # Rebuild into a fresh graph to eliminate stale C-level node pointers # left by erase_node(), which can cause segfaults during dynamo re-trace. 
traced_lower = _rebuild_graph_module(traced_lower) @@ -510,22 +512,37 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: class _CompiledModel(torch.nn.Module): - """Coord extension (eager) -> compiled forward_lower (dynamic shapes).""" + """Coord extension (eager) -> compiled forward_lower (dynamic shapes). + + Compilation is lazy: ``_trace_and_compile`` is called on the first real + ``forward()`` invocation using that batch's tensors, so no extra + ``get_data()`` call is needed during ``__init__``. Tasks that share the + same model structure reuse the compiled graph via ``compiled_by_structure``. + """ def __init__( self, original_model: torch.nn.Module, - compiled_forward_lower: torch.nn.Module, + structure_key: tuple[int, ...], task_buf_order: tuple[str, ...] = (), task_buffers: dict[str, torch.Tensor] | None = None, + compile_opts: dict[str, Any] | None = None, + compiled_by_structure: dict | None = None, ) -> None: super().__init__() self.original_model = original_model - self.compiled_forward_lower = compiled_forward_lower + self.compiled_forward_lower: torch.nn.Module | None = None self._task_buf_order = task_buf_order - # task_buffers is intentionally not stored: buffers are read from - # original_model.get_fitting_net() at forward time so that weight - # updates (load_state_dict, optimiser steps) are always reflected. + self._structure_key = structure_key + self._compile_opts = compile_opts + # Stored only for the first-forward compile call; freed afterwards. + self._task_buffers = task_buffers + # Shared dict across all _CompiledModel instances in the same Trainer. + # A cache hit lets a second task with the same structure reuse the + # already-traced graph without re-running make_fx. 
+ self._compiled_by_structure: dict = ( + compiled_by_structure if compiled_by_structure is not None else {} + ) def __getattr__(self, name: str) -> Any: # Delegate unknown lookups to original_model so that callers such as @@ -582,6 +599,96 @@ def forward( distinguish_types=False, ) ext_coord = ext_coord.reshape(nframes, -1, 3) + + # Mirror the uncompiled path's optional-input defaulting (see + # ``SeZMModel._forward_common`` -> ``convert_fparam_aparam`` / + # ``convert_charge_spin``): a model configured with fparam or + # charge_spin (``dim > 0``) substitutes its default when the data + # omits it. The compiled ``forward_lower`` is frozen to the *traced* + # branch -- a present optional input bakes ``aten._to_copy(x, ...)`` + # into the graph, while an absent one is dropped during make_fx pytree + # flattening -- so these inputs must be normalized to tensors here, + # before both tracing and every compiled call. Otherwise a graph + # traced with the input present crashes when a later call (e.g. a + # share_params task whose dataset omits it and relies on the default) + # invokes it with None. ``aparam`` has no default (it is required + # whenever ``dim_aparam > 0``), so it needs no normalization; a genuine + # absence is reported by ``forward_lower`` itself, as in eager mode. + # ``get_default_*`` may return either a tensor or a raw ``list[float]`` + # (the sezm descriptor stores ``default_chg_spin`` as a list, and only + # ``sezm_atomic_model`` wraps it via ``new_tensor``; the dp_atomic_model + # family returns the descriptor list as-is), so coerce with + # ``torch.as_tensor`` and ``reshape`` to ``(1, dim)`` before broadcasting. 
+ _model = self.original_model + _dim_fparam = ( + _model.get_dim_fparam() if hasattr(_model, "get_dim_fparam") else 0 + ) + if fparam is None and _dim_fparam > 0: + _default_fparam = _model.get_default_fparam() + if _default_fparam is not None: + fparam = ( + torch.as_tensor( + _default_fparam, dtype=ext_coord.dtype, device=ext_coord.device + ) + .reshape(1, _dim_fparam) + .expand(nframes, -1) + ) + _dim_cs = ( + _model.get_dim_chg_spin() if hasattr(_model, "get_dim_chg_spin") else 0 + ) + if charge_spin is None and _dim_cs > 0: + _default_cs = _model.get_default_chg_spin() + if _default_cs is not None: + charge_spin = ( + torch.as_tensor( + _default_cs, dtype=ext_coord.dtype, device=ext_coord.device + ) + .reshape(1, _dim_cs) + .expand(nframes, -1) + ) + + # Lazy compile: trace on the first real forward call using this + # batch's tensors (prime-padded inside _trace_and_compile). + # Mirrors DPA4's on-cache-miss compile so no separate get_data() + # is needed during __init__. + if self.compiled_forward_lower is None: + # Optional inputs (fparam / charge_spin) are normalized to their + # defaults above, so their presence is now config-driven (a + # function of the model's ``dim_*``) rather than data-driven. + # Tasks sharing this structure key share the same descriptor / + # fitting net and therefore the same dims, so a single compiled + # graph is safe to reuse across them. 
+ if self._structure_key in self._compiled_by_structure: + compiled_lower, buf_order = self._compiled_by_structure[ + self._structure_key + ] + log.info("Reusing compiled graph (shared model structure, lazy).") + else: + log.info( + "Lazy compile: tracing model on first forward call " + "(structure_key=%s).", + self._structure_key, + ) + compiled_lower, buf_order = _trace_and_compile( + self.original_model, + ext_coord, + ext_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin=charge_spin, + task_buffers=self._task_buffers, + compile_opts=self._compile_opts, + ) + self._compiled_by_structure[self._structure_key] = ( + compiled_lower, + buf_order, + ) + self.compiled_forward_lower = compiled_lower + self._task_buf_order = buf_order + self._task_buffers = None # free; no longer needed after compile + ext_coord = ext_coord.detach().requires_grad_(True) if self._task_buf_order: @@ -1139,14 +1246,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: # 'meta'`` (pytorch/pytorch#134182). torch._dynamo.config.optimize_ddp = False - from deepmd.dpmodel.utils.nlist import ( - build_neighbor_list, - extend_coord_with_ghosts, - ) - from deepmd.dpmodel.utils.region import ( - normalize_coord, - ) - # Under DDP, self.wrapper is a DistributedDataParallel wrapper; # access the underlying ModelWrapper via .module. wrapper_mod = ( @@ -1200,9 +1299,10 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: wrapper_mod.model[task_key], group_models ) - # structure_key -> (compiled_lower, task_buf_order) - # Tasks with the same structure key (same descriptor + shared fitting) - # reuse the compiled graph; different descriptor or fitting → distinct key. + # Shared cache: structure_key -> (compiled_lower, task_buf_order). + # Tasks with the same structure key reuse the same compiled graph. + # The dict is passed to every _CompiledModel instance so the lazy + # compile on the first forward can populate and share it. 
_compiled_by_structure: dict[tuple[int, ...], tuple] = {} for task_key in self.model_keys: @@ -1233,69 +1333,16 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: structure_key = _key_for[task_key] task_bufs = _task_bufs_for[task_key] - if structure_key in _compiled_by_structure: - # Shared structure: reuse the already-compiled graph. - compiled_lower, task_buf_order = _compiled_by_structure[structure_key] - log.info( - "Reusing compiled graph for task=%s (shared model structure).", - task_key, - ) - else: - inp, _ = self.get_data(is_train=True, task_key=task_key) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() - - nframes, nloc = atype.shape[:2] - coord_3d = coord.reshape(nframes, nloc, 3) - box_flat = box.reshape(nframes, 9) if box is not None else None - - if box_flat is not None: - coord_norm = normalize_coord( - coord_3d, box_flat.reshape(nframes, 3, 3) - ) - else: - coord_norm = coord_3d - - ext_coord, ext_atype, mapping = extend_coord_with_ghosts( - coord_norm, atype, box_flat, model.get_rcut() - ) - nlist_t = build_neighbor_list( - ext_coord, - ext_atype, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord = ext_coord.reshape(nframes, -1, 3) - - fparam = inp.get("fparam") - aparam = inp.get("aparam") - charge_spin = inp.get("charge_spin") - - compiled_lower, task_buf_order = _trace_and_compile( - model, - ext_coord, - ext_atype, - nlist_t, - mapping, - fparam, - aparam, - charge_spin=charge_spin, - task_buffers=task_bufs if task_bufs else None, - compile_opts=compile_opts, - ) - _compiled_by_structure[structure_key] = (compiled_lower, task_buf_order) - wrapper_mod.model[task_key] = _CompiledModel( - model, compiled_lower, task_buf_order, task_bufs + model, + structure_key=structure_key, + task_buf_order=tuple(task_bufs.keys()) if task_bufs else (), + task_buffers=task_bufs if task_bufs else None, + 
compile_opts=compile_opts, + compiled_by_structure=_compiled_by_structure, ) log.info( - "Model compiled/reused (task=%s, tracing_mode=symbolic, " - "dynamic=True, backend=inductor).", + "Lazy compile registered (task=%s); will trace on first forward call.", task_key, ) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index cbb8368074..1651bcac0c 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -307,9 +307,9 @@ class TestCompiledModelGetattr(unittest.TestCase): """Unit tests for _CompiledModel attribute delegation. These tests do not require example data or torch.compile — they use a - lightweight mock original_model and a no-op compiled_forward_lower to - verify that __getattr__ correctly forwards unknown attributes/methods to - the wrapped original model. + lightweight mock original_model to verify that __getattr__ correctly + forwards unknown attributes/methods to the wrapped original model. + Compilation is lazy, so no compiled_forward_lower is needed for construction. """ def _make_compiled_model(self): @@ -317,10 +317,6 @@ def _make_compiled_model(self): _CompiledModel, ) - class _FakeForwardLower(torch.nn.Module): - def forward(self, *a, **kw): - pass - class _FakeModel(torch.nn.Module): def get_rcut(self): return 3.0 @@ -335,7 +331,7 @@ def atomic_model(self): def get_descriptor(self): return self - return _CompiledModel(_FakeModel(), _FakeForwardLower()) + return _CompiledModel(_FakeModel(), structure_key=(7,)) def test_delegates_method(self) -> None: """Unknown method calls are forwarded to original_model.""" @@ -355,10 +351,12 @@ def test_delegates_property(self) -> None: def test_own_attrs_not_delegated(self) -> None: """Attributes owned by _CompiledModel itself are NOT delegated.""" cm = self._make_compiled_model() - # original_model and compiled_forward_lower are registered submodules - # of _CompiledModel — they must not fall through to delegation. 
+ # original_model is a registered submodule and must not fall through + # to delegation. compiled_forward_lower is None before the first + # forward call (lazy compile) — accessing it must return None, not + # delegate to original_model. self.assertIsInstance(cm.original_model, torch.nn.Module) - self.assertIsInstance(cm.compiled_forward_lower, torch.nn.Module) + self.assertIsNone(cm.compiled_forward_lower) def test_missing_attr_raises(self) -> None: """Accessing an attribute missing from both wrapper and original raises.""" @@ -403,9 +401,11 @@ def test_compiled_handles_varying_nall(self) -> None: # The wrapper.model should be a _CompiledModel compiled_model = trainer.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) + # Lazy compile: compiled_forward_lower is None before any forward. + self.assertIsNone(compiled_model.compiled_forward_lower) trainer.wrapper.train() - for _ in range(3): + for step in range(3): trainer.optimizer.zero_grad(set_to_none=True) inp, lab = trainer.get_data(is_train=True) lr = trainer.scheduler.get_last_lr()[0] @@ -413,6 +413,10 @@ def test_compiled_handles_varying_nall(self) -> None: loss.backward() trainer.optimizer.step() + # After first forward, compiled_forward_lower must be set. + if step == 0: + self.assertIsNotNone(compiled_model.compiled_forward_lower) + # Loss should be a finite scalar at every step self.assertFalse(torch.isnan(loss)) self.assertFalse(torch.isinf(loss))