From a408dbcc4cc6c426b988d1e512df294e3fb18e00 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 12 Jun 2026 14:48:35 +0800 Subject: [PATCH 1/3] perf(dpa4): remove sync --- deepmd/pt/model/descriptor/sezm_nn/edge_cache.py | 9 ++++++--- deepmd/pt/model/descriptor/sezm_nn/utils.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py index 0545c82b76..19d25e8d71 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py +++ b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py @@ -20,6 +20,7 @@ ) import torch +import torch.nn.functional as F from einops import ( rearrange, ) @@ -450,7 +451,10 @@ def build_edge_cache_from_edges( edge_vec = edge_vec.to(dtype=compute_dtype) edge_keep_f = edge_keep.to(dtype=compute_dtype).unsqueeze(-1) edge_vec = edge_vec * edge_keep_f - edge_vec = edge_vec + (1.0 - edge_keep_f) * edge_vec.new_tensor([0.0, 0.0, 1.0]) + # Masked-out edges (zeroed above) are assigned the canonical +z direction so the + # length normalization and quaternion construction remain finite. Padding the + # keep-complement into the z channel constructs this term entirely on device. + edge_vec = edge_vec + F.pad(1.0 - edge_keep_f, (2, 0)) # === Step 3. 
Edge length, envelope, and radial basis === with nvtx_range("envelope"): @@ -620,9 +624,8 @@ def _finalize_edge_cache( with nvtx_range("degree"): deg = torch.zeros(n_nodes, dtype=edge_vec.dtype, device=edge_vec.device) # (N,) deg.index_add_(0, dst, edge_env.squeeze(-1).to(dtype=edge_vec.dtype).square()) - floor_tensor = deg.new_tensor(deg_norm_floor) inv_sqrt_deg = rearrange( - torch.rsqrt(deg + floor_tensor), "N -> N 1 1" + torch.rsqrt(deg + deg_norm_floor), "N -> N 1 1" ) # (N, 1, 1) return EdgeFeatureCache( diff --git a/deepmd/pt/model/descriptor/sezm_nn/utils.py b/deepmd/pt/model/descriptor/sezm_nn/utils.py index 0fc7a92b4c..6bb1933b01 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/utils.py +++ b/deepmd/pt/model/descriptor/sezm_nn/utils.py @@ -129,7 +129,7 @@ def safe_norm(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: in_dtype = x.dtype if in_dtype in (torch.float16, torch.bfloat16): x = x.float() - eps_sq = x.new_tensor(float(eps) * float(eps)) + eps_sq = float(eps) * float(eps) norm = torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True) + eps_sq) return norm.to(dtype=in_dtype) From c99d187b34848c1e0e94bb3741f39ec86029eba9 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 12 Jun 2026 17:37:47 +0800 Subject: [PATCH 2/3] feat(dpa4): use edge force and atomic virial --- deepmd/pt/entrypoints/freeze_pt2.py | 8 +- deepmd/pt/model/model/sezm_model.py | 502 ++++++++---------- deepmd/pt/model/model/transform_output.py | 108 ++++ source/tests/pt/model/test_sezm_model.py | 287 ++++++++-- source/tests/pt/model/test_sezm_spin_model.py | 33 +- 5 files changed, 588 insertions(+), 350 deletions(-) diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index c85671b178..d28d7789f7 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -477,7 +477,7 @@ def freeze_sezm_to_pt2( *, device: torch.device | None = None, head: str | None = None, - atomic_virial: bool = False, + atomic_virial: bool = True, ) 
-> None: """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive. @@ -495,8 +495,10 @@ def freeze_sezm_to_pt2( ``Default`` head is used when present; otherwise multi-task checkpoints must pass an explicit head. Single-task checkpoints must pass ``None``. atomic_virial - Whether to include per-atom virial outputs in the exported graph. - Disable this for fastest LAMMPS force/energy/total-virial inference. + Whether the exported model exposes per-atom virial. Enabled by + default: the edge-force scatter assembles the per-atom virial as a + free by-product of the single backward, so exporting it carries no + compute cost. """ from torch._inductor import ( aoti_compile_and_package, diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 24271f6461..a6cb1f538d 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -68,7 +68,7 @@ |-- input dtype cast |-- neighbor list built in the extended region '-- forward_common -- ener branch - |-- extended_coord.detach().requires_grad_(True) (NOTE 9) + |-- coords pass through; edge_vec leaf is the grad endpoint (NOTE 9) |-- should_use_compile()? yes -> | |-- trace_and_compile() on cache miss | | |-- make_fx(compute_fn, @@ -78,7 +78,7 @@ | | | * trace inputs use safe prime dims (NOTE 1) | | | * silu_backward is decomposed (NOTE 2) | | | * traced graph already contains the - | | | first autograd.grad over coords + | | | first autograd.grad over edge vecs | | |-- strip_saved_tensor_detach (train only) (NOTE 3) | | |-- rebuild_graph_module (train only) (NOTE 4) | | |-- train: torch.compile(backend="inductor", @@ -89,10 +89,10 @@ '-- communicate_extended_output Subsequent batches look up the cached callable at the same -``(training, do_atomic_virial, has_coord_corr)`` slot of -``compiled_core_compute_cache``. 
Each slot is retained independently, so -train <-> eval toggles around every ``disp_freq`` / full-validation checkpoint -reuse the other slot's compile product instead of evicting it (NOTE 7). +``(training, has_coord_corr)`` slot of ``compiled_core_compute_cache``. +Each slot is retained independently, so train <-> eval toggles around every +``disp_freq`` / full-validation checkpoint reuse the other slot's compile +product instead of evicting it (NOTE 7). Body of the traced compute ========================== @@ -104,19 +104,21 @@ * ``core_compute`` rebuilds a compact, GPU-friendly edge list from the padded DeePMD neighbor list (``build_edge_list_from_nlist``), with two masked dummy edges appended so the edge tensor has a non-singular - symbolic lower bound (NOTE 10). Edge vectors come from - ``index_select`` on the extended - coordinate tensor, which keeps the gradient path back to coordinates - explicit and safe under symbolic shapes (NOTE 11). -* The SeZM descriptor consumes the edge list and produces per-atom - features. + symbolic lower bound (NOTE 10). Edge vectors are gathered from the + extended coordinate tensor and then **detached into a fresh leaf** + (``edge_vec.detach().requires_grad_(True)``), so the gather lives + entirely outside the autograd region (NOTE 11). +* The SeZM descriptor and the analytical ZBL term (``InterPotential``) + both consume that edge-vector leaf, so the energy depends on coordinates + *only* through ``edge_vec``. * The fitting network predicts per-atom energy; ``apply_out_stat`` adds the per-type statistics and the atom mask zeroes out padding atoms. -* ``fit_output_to_model_output(..., create_graph=self.training)`` calls - ``autograd.grad`` internally to compute ``force = -dE/dx``. - ``create_graph`` is the single toggle that activates the - second-derivative branch for training and omits it at inference - (NOTE 12). 
+* ``edge_energy_deriv(..., create_graph=self.training)`` runs a *single* + ``autograd.grad(energy, edge_vec)`` (edge-force scatter) and assembles + force, global virial and per-atom virial by scattering the per-edge + gradient back onto the extended atoms. ``create_graph`` is + the single toggle that activates the second-derivative branch for + training and omits it at inference (NOTE 12). Because ``make_fx`` traces *after* that inner ``autograd.grad`` has executed, the resulting FX graph encodes both the forward and the first @@ -280,19 +282,15 @@ NOTE 7 -- Multi-slot compile cache key -------------------------------------- -The key is ``(training, do_atomic_virial, has_coord_corr)`` because all three -fields alter the traced graph topology: +The key is ``(training, has_coord_corr)``: -* ``self.training`` switches ``create_graph`` in - ``fit_output_to_model_output`` -- it toggles the entire - second-derivative branch on or off. -* ``do_atomic_virial`` adds or removes an extra per-atom virial tensor - in the compute output. +* ``self.training`` switches ``create_graph`` in ``edge_energy_deriv`` -- + it toggles the entire second-derivative branch on or off. * ``has_coord_corr`` selects the spin-virial correction branch, changing the compiled callable arity from six tensor inputs to seven. No single compiled graph can serve both variants, so the cache is a -``dict[tuple[bool, bool, bool], Callable]`` named +``dict[tuple[bool, bool], Callable]`` named ``compiled_core_compute_cache``. A single-slot cache would have to evict on every flip, which turns the normal training-loop pattern -- ``train -> eval at every disp_freq -> train`` @@ -335,21 +333,19 @@ arbitrary attributes; ``object.__setattr__`` merely belt-and-braces this invariant for readers of the constructor. 
-NOTE 9 -- Graph restart via ``detach().requires_grad_(True)`` -------------------------------------------------------------- +NOTE 9 -- Coordinate detach (trace inputs only) +----------------------------------------------- -Before calling into the traced graph we rebind the extended coordinates -to a fresh leaf tensor: ``detach()`` breaks any upstream autograd graph -carried over from the data pipeline, and ``requires_grad_(True)`` -reinstates a grad-endpoint owned by this forward. The subsequent -``autograd.grad`` in ``fit_output_to_model_output`` therefore computes -``dE/dx`` against a graph of known shape and ownership -- the essential -precondition for make_fx symbolic tracing. - -In eval the rebound coordinate still requires grad for the eager -(non-compiled) path's ``autograd.grad``, but the compiled callable runs -under ``torch.no_grad`` so its AOTAutograd inference lowering builds no -outer backward (NOTE 13). +The force autograd endpoint is the per-edge ``edge_vec`` leaf created inside +``core_compute`` (NOTE 11), so the runtime forward passes coordinates through +as-is -- they are never a grad endpoint, and ``edge_vec.detach()`` inside +``core_compute`` already severs any upstream graph the batch might carry. +The trace paths still ``detach()`` their coordinate input: ``compute_fn`` +(via ``_prepare_coord_for_trace``) and the export ``lower_fn`` do so to root +the make_fx / export graph cleanly at ``edge_vec`` with no stray coordinate +grad path. ``autograd.grad`` then computes ``dE/d(edge_vec)`` against a graph +of known shape and ownership, while the coordinate gather sits outside the +differentiated region. NOTE 10 -- Tail dummy edges --------------------------- @@ -368,24 +364,29 @@ dummy's ``edge_mask`` is ``False`` so it contributes exactly zero to every downstream sum or gather. 
-NOTE 11 -- ``index_select`` for coordinate gradients ----------------------------------------------------- +NOTE 11 -- Edge-vector leaf (gather outside the AD region) +---------------------------------------------------------- -Edge geometry is built with ``coord_flat.index_select(0, src)`` instead -of advanced indexing ``coord_flat[src]``. ``index_select`` registers -an explicit backward that routes gradient cleanly back to the original -extended coordinate tensor. Advanced indexing combined with make_fx -symbolic shapes has previously produced silent gradient truncation in -this project -- the second-derivative gradient over coordinates was -effectively zero, with no error raised. +Edge geometry is gathered from the coordinates with ``torch.gather`` and +then detached into a fresh leaf, ``edge_vec.detach().requires_grad_(True)``. +The single ``autograd.grad(energy, edge_vec)`` of the edge-force scatter +truncates the backward at the edge vectors, so the coordinate gather is a +pure forward op whose backward is never traversed. This keeps the +differentiated region a pure function ``(edge_vec, theta) -> E`` -- gather +and advanced indexing *inside* the differentiated region can silently +truncate the second-order gradient under make_fx, so keeping them out of +that region matters. Force, virial and per-atom virial are then explicit +scatter / outer-product ops rather than autograd by-products. NOTE 12 -- ``create_graph=self.training`` ----------------------------------------- The single toggle that turns force-loss training on. When ``True``, -``autograd.grad`` keeps the graph over the first derivative alive so -the outer optimizer's ``.backward()`` can continue walking it into the -parameters. 
When ``False`` the double-backward graph is never built, +``autograd.grad(energy, edge_vec)`` keeps the graph over the first +derivative alive so the outer optimizer's ``.backward()`` can continue +walking ``d^2 E / (d edge dtheta)`` into the parameters (the per-edge +gradient feeds the explicit scatter, whose own backward is a cheap +gather). When ``False`` the double-backward graph is never built, saving memory during inference. NOTE 13 -- Inference lowering through ``aot_module_simplified`` @@ -400,9 +401,10 @@ view's extended-atom (``nall``) axis becomes a backed symbol with no input source, and ``produce_guards`` aborts with ``sources must not be empty for symbol s...``. -* Even when it compiles, the grad-bearing input makes AOTAutograd treat - the call as forward+backward and keep the whole forward activation set - alive -- ~3x the eager peak memory, OOM-ing on large inference sweeps. +* Even when it compiles, the materialised first-derivative graph (the + edge backward baked into the forward) makes AOTAutograd treat the call + as forward+backward and keep the whole forward activation set alive -- + ~3x the eager peak memory, OOM-ing on large inference sweeps. So eval lowers the graph with ``aot_module_simplified`` -- AOTAutograd's inference path -- with no Dynamo frontend. It still functionalizes the @@ -483,7 +485,7 @@ ) from deepmd.pt.model.model.transform_output import ( communicate_extended_output, - fit_output_to_model_output, + edge_energy_deriv, ) from deepmd.pt.utils import ( env, @@ -530,10 +532,10 @@ # --------------------------------------------------------------------------- # Multi-task compile sharing # --------------------------------------------------------------------------- -# Maps (structure_key..., training, do_atomic_virial, has_coord_corr) to the -# compiled callable. 
Tasks whose descriptor and fitting parameters share the -# same Python-object identity after ``share_params(level=0)`` reuse one compiled -# graph, avoiding N x compile-cache growth and duplicated DDP graph boundaries. +# Maps (structure_key..., training, has_coord_corr) to the compiled callable. +# Tasks whose descriptor and fitting parameters share the same Python-object +# identity after ``share_params(level=0)`` reuse one compiled graph, avoiding +# N x compile-cache growth and duplicated DDP graph boundaries. _SEZM_COMPILE_CACHE: dict[tuple[Any, ...], Any] = {} # Maps structure_key -> task_buf_order so every instance in the same group @@ -677,15 +679,15 @@ def __init__( self.lora_config: dict[str, Any] | None = None if lora is None else dict(lora) self._dens_compiled = False self._core_compute_pending_compile_t0: float | None = None - self._core_compute_pending_compile_key: tuple[bool, bool, bool] | None = None + self._core_compute_pending_compile_key: tuple[bool, bool] | None = None self._dens_pending_compile_t0: float | None = None # Store compiled callables outside the nn.Module tree so that # FSDP2 / DDP do not shard or sync its duplicated parameters. # ``compiled_core_compute_cache`` is keyed on - # ``(training, do_atomic_virial, has_coord_corr)`` so every graph - # topology has its own slot; flipping between train and eval for - # validation -- regular, full, or EMA full -- therefore reuses cached - # compile products instead of evicting the other mode. + # ``(training, has_coord_corr)`` so every graph topology has its own + # slot; flipping between train and eval for validation -- regular, + # full, or EMA full -- therefore reuses cached compile products + # instead of evicting the other mode. object.__setattr__(self, "compiled_core_compute_cache", {}) object.__setattr__(self, "compiled_dens_compute", None) # Maps cache_key -> task_buf_order for this instance so forward() @@ -1049,23 +1051,6 @@ def forward_common_after_nlist( ) else: # === Step 1. 
`ener` path (edges built inside core_compute) === - # NOTE: Rebind the extended coordinates to a fresh leaf - # tensor before entering either ``core_compute`` or the - # compiled callable. ``detach()`` breaks any upstream - # autograd graph carried by the batch (data pipeline - # artefacts, neighbor-list ops) and - # ``requires_grad_(True)`` reinstates a grad-endpoint - # owned exclusively by this forward. The inner - # ``autograd.grad`` inside ``fit_output_to_model_output`` - # will then compute ``dE/dx`` against a graph of known - # shape and ownership -- the essential precondition for - # symbolic make_fx tracing. In eval without coordinate - # gradients a bare detach is enough. - if self.do_grad_r() or self.do_grad_c(): - extended_coord = extended_coord.detach().requires_grad_(True) - else: - extended_coord = extended_coord.detach() - with self.tf32_precision_ctx(): if self.should_use_compile(): fp, ap = self.convert_fp_ap( @@ -1077,11 +1062,7 @@ def forward_common_after_nlist( device=extended_coord.device, ) has_coord_corr = extended_coord_corr is not None - cache_key = ( - bool(self.training), - bool(do_atomic_virial), - has_coord_corr, - ) + cache_key = (bool(self.training), has_coord_corr) if cache_key not in self.compiled_core_compute_cache: self.trace_and_compile( extended_coord, @@ -1091,7 +1072,6 @@ def forward_common_after_nlist( fp, ap, charge_spin, - do_atomic_virial, extended_coord_corr=extended_coord_corr, ) compiled_core_compute = self.compiled_core_compute_cache[cache_key] @@ -1142,10 +1122,8 @@ def forward_common_after_nlist( torch.cuda.synchronize() log.info( "SeZM: finished compiling " - "(mode=%s, atomic_virial=%s, coord_corr=%s) " - "in %.2fs", + "(mode=%s, coord_corr=%s) in %.2fs", "train" if self.training else "eval", - do_atomic_virial, has_coord_corr, time.perf_counter() - self._core_compute_pending_compile_t0, ) @@ -1161,7 +1139,6 @@ def forward_common_after_nlist( fparam=fp, aparam=ap, charge_spin=charge_spin, - 
do_atomic_virial=do_atomic_virial, extra_nlist_sort=self.need_sorted_nlist_for_lower(), extended_coord_corr=extended_coord_corr, ) @@ -1188,7 +1165,6 @@ def core_compute( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, - do_atomic_virial: bool = False, comm_dict: dict[str, torch.Tensor] | None = None, extra_nlist_sort: bool = False, extended_coord_corr: torch.Tensor | None = None, @@ -1198,7 +1174,8 @@ def core_compute( Builds compact sparse edges, runs descriptor and fitting evaluation, applies output masking and the optional analytical pair potential, - then calls ``fit_output_to_model_output`` for force / virial. + then calls ``edge_energy_deriv`` (edge-force scatter) for force / + virial / per-atom virial. Parameters ---------- @@ -1216,8 +1193,6 @@ def core_compute( Atomic parameters with shape (nf, nloc, nda), or ``None``. charge_spin Frame-level charge and spin conditions with shape `(nf, 2)`. - do_atomic_virial - Whether to compute per-atom virial. comm_dict Communication data for parallel inference. Currently unused. extra_nlist_sort @@ -1228,22 +1203,33 @@ def core_compute( Returns ------- dict[str, torch.Tensor] - DeePMD lower-style outputs (energy, energy_redu, energy_derv_r, ...). + DeePMD lower-style outputs (energy, energy_redu, energy_derv_r, + energy_derv_c, energy_derv_c_redu, mask). The per-atom virial + (energy_derv_c) is always produced; callers decide whether to keep + it. """ del comm_dict nlist = self.format_nlist( extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort ) - _, nloc, _ = nlist.shape + nf, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] descriptor_model = self.atomic_model.descriptor # === Step 1. 
Build compact sparse edges === - edge_index, edge_vec, edge_mask = self.build_edge_list_from_nlist( - extended_coord=extended_coord, - nlist=nlist, - mapping=mapping, + edge_index, edge_vec, edge_mask, edge_index_ext = ( + self.build_edge_list_from_nlist( + extended_coord=extended_coord, + nlist=nlist, + mapping=mapping, + ) ) + # Edge displacements are the autograd leaf for the force / virial + # backward. The coordinate gather that produced ``edge_vec`` stays a + # pure forward op, so the differentiated region is the function + # ``(edge_vec, theta) -> E``; this keeps the make_fx symbolic trace and + # second-order lowering clean (see doc/outisli/dpa4.md §12.4). + edge_vec = edge_vec.detach().requires_grad_(True) # === Step 2. Descriptor forward === with nvtx_range("SeZM/descriptor"): @@ -1292,39 +1278,51 @@ def core_compute( ).view(out_shape) fit_ret["mask"] = atom_mask - # === Step 5. Inject analytical pair potential === + # === Step 5. Inject analytical pair potential (edge form) === + # ZBL is evaluated from ``edge_vec`` (the autograd leaf) so its force + # and virial flow through the same edge backward as the learned energy. if self.inter_potential is not None: fit_ret["energy"] = fit_ret["energy"] + self.inter_potential( - extended_coord, - extended_atype, - nlist, - nloc, + edge_vec=edge_vec, + edge_index=edge_index, + atype_flat=atype.reshape(-1), + edge_mask=edge_mask, + n_node=nf * nloc, real_type_count=self._get_inter_potential_real_type_count(), - ) - - # === Step 6. Force / virial via fit_output_to_model_output === - # NOTE: ``create_graph=self.training`` is the single toggle that - # activates force-loss training. Internally this calls - # ``torch.autograd.grad(energy, extended_coord, create_graph=...)`` - # to produce ``force = -dE/dx``. 
When ``True`` the autograd graph - # over the first derivative is kept alive, so the outer - # optimiser's ``.backward()`` can continue differentiating into - # parameters -- that chain is the full - # ``d^2 E / (dx dtheta)`` second derivative. When ``False`` the - # double-backward graph is never built, saving memory during - # inference. The entire reason this file exists -- make_fx, - # detach stripping, graph rebuild -- is to keep that - # second-derivative chain intact after ``torch.compile`` has - # captured the whole thing. - return fit_output_to_model_output( - fit_ret, - self.atomic_output_def(), - extended_coord, - do_atomic_virial=do_atomic_virial, + ).view(nf, nloc, 1) + + # === Step 6. Force / virial via edge-force scatter === + # A single ``autograd.grad(energy, edge_vec)`` inside + # ``edge_energy_deriv`` produces force, global virial and per-atom + # virial together as per-ghost extended tensors for the downstream + # ``communicate_extended_output`` / lower-interface contract. + # ``create_graph=self.training`` keeps the first-derivative graph alive + # so the optimiser's backward can reach the parameters through + # ``d^2 E / (d edge dtheta)`` during force-loss training. + energy_atom = fit_ret["energy"] + energy_redu = torch.sum( + energy_atom.to(env.GLOBAL_PT_ENER_FLOAT_PRECISION), dim=1 + ) + nall = extended_coord.shape[1] + energy_derv_r, energy_derv_c, energy_derv_c_redu = edge_energy_deriv( + energy_redu, + edge_vec, + edge_index_ext[0], + edge_index_ext[1], + edge_mask, + nf, + nall, create_graph=self.training, - mask=fit_ret["mask"], extended_coord_corr=extended_coord_corr, ) + return { + "energy": energy_atom, + "energy_redu": energy_redu, + "energy_derv_r": energy_derv_r, + "energy_derv_c": energy_derv_c, + "energy_derv_c_redu": energy_derv_c_redu, + "mask": fit_ret["mask"], + } def core_compute_dens( self, @@ -1385,7 +1383,7 @@ def core_compute_dens( descriptor_model = self.atomic_model.descriptor # === Step 1. 
Build compact sparse edges === - edge_index, edge_vec, edge_mask = self.build_edge_list_from_nlist( + edge_index, edge_vec, edge_mask, _ = self.build_edge_list_from_nlist( extended_coord=extended_coord, nlist=nlist, mapping=mapping, @@ -1569,8 +1567,6 @@ def forward_common_lower( extended_coord_corr = extended_coord_corr.reshape( extended_atype.shape[0], -1, 3 ) - if self.do_grad_r() or self.do_grad_c(): - cc_ext = cc_ext.detach().requires_grad_(True) nf = extended_atype.shape[0] charge_spin = self.convert_charge_spin( charge_spin, @@ -1586,7 +1582,6 @@ def forward_common_lower( fparam=fp, aparam=ap, charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, extra_nlist_sort=extra_nlist_sort, extended_coord_corr=extended_coord_corr, @@ -1606,7 +1601,6 @@ def trace_and_compile( fp: torch.Tensor, ap: torch.Tensor, charge_spin: torch.Tensor, - do_atomic_virial: bool, extended_coord_corr: torch.Tensor | None = None, ) -> None: """Trace ``core_compute()`` with ``make_fx`` and cache the compiled callable. @@ -1634,7 +1628,7 @@ def trace_and_compile( # should share one compiled graph. If a sibling task already compiled, # populate this instance's per-instance caches and return immediately. 
structure_key = _sezm_structure_key(self) - cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) + cache_key = (bool(self.training), has_coord_corr) full_cache_key = structure_key + cache_key if full_cache_key in _SEZM_COMPILE_CACHE: self.compiled_core_compute_cache[cache_key] = _SEZM_COMPILE_CACHE[ @@ -1642,19 +1636,15 @@ def trace_and_compile( ] self._task_buf_order_cache[cache_key] = _SEZM_TASK_BUF_ORDER[structure_key] log.info( - "SeZM: reusing shared compiled graph " - "(mode=%s, atomic_virial=%s, coord_corr=%s)", + "SeZM: reusing shared compiled graph (mode=%s, coord_corr=%s)", mode, - do_atomic_virial, has_coord_corr, ) return log.info( - "SeZM: start tracing and compiling " - "(mode=%s, atomic_virial=%s, coord_corr=%s)", + "SeZM: start tracing and compiling (mode=%s, coord_corr=%s)", mode, - do_atomic_virial, has_coord_corr, ) @@ -1712,24 +1702,19 @@ def _restore_task_bufs( actual = name[len(FIT_PREFIX) :] _fitting_patch._buffers[actual] = orig - need_coord_grad = self.do_grad_r() or self.do_grad_c() - def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: - """Restart the coordinate autograd graph for the traced compute. - - ``detach()`` severs any upstream graph carried by the trace - inputs and ``requires_grad_(True)`` reinstates a fresh - grad-endpoint owned by this compute. The inner - ``autograd.grad`` inside ``fit_output_to_model_output`` then - differentiates against a graph of known shape and ownership -- - the essential precondition for make_fx symbolic tracing to - capture dE/dx as ordinary FX nodes. In the eval-only branch - a bare detach keeps the traced graph free of backward sections. + """Detach the trace input coordinates from any upstream graph. + + The force-autograd endpoint is the per-edge ``edge_vec`` leaf + created inside ``core_compute`` (edge-force scatter), so the + coordinates themselves do not carry a grad endpoint. 
``detach()`` + severs any upstream graph from the trace inputs; the inner + ``autograd.grad(energy, edge_vec)`` then differentiates against a + graph of known shape and ownership rooted at ``edge_vec`` -- the + precondition for make_fx to capture dE/d(edge) as ordinary FX + nodes -- while keeping the coordinate gather out of the AD region. """ - if need_coord_grad: - return coord.detach().requires_grad_(True) - else: - return coord.detach() + return coord.detach() # NOTE: compute_fn accepts *task_buf_vals after the fixed tensor args. # make_fx treats each element as a separate placeholder so the compiled @@ -1760,7 +1745,6 @@ def compute_fn( fparam=fp, aparam=ap, charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) finally: @@ -1792,7 +1776,6 @@ def compute_fn( # type: ignore[misc] fparam=fp, aparam=ap, charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, extra_nlist_sort=self.need_sorted_nlist_for_lower(), extended_coord_corr=extended_coord_corr, ) @@ -1936,9 +1919,9 @@ def compute_fn( # type: ignore[misc] # from parameter discovery (FSDP2/DDP would otherwise shard or # synchronise the wrapper's duplicated flat parameter views and # silently corrupt training). The cache is keyed on - # ``(training, do_atomic_virial, has_coord_corr)`` so that distinct - # graph topologies coexist without evicting each other on every - # ``model.eval()`` / ``model.train()`` switch. + # ``(training, has_coord_corr)`` so that distinct graph topologies + # coexist without evicting each other on every ``model.eval()`` / + # ``model.train()`` switch. # NOTE: ``dynamic=True`` emits a single kernel per traced # shape symbol, so changes in ``nframes``, ``nall`` or edge # count do not trigger recompiles; and the option dict above @@ -2135,8 +2118,8 @@ def forward_common_lower_exportable( ) -> torch.nn.Module: """Trace ``forward_common_lower`` into an exportable FX ``GraphModule``. 
- ``make_fx`` unfolds the inner ``autograd.grad`` that - ``fit_output_to_model_output`` performs for force and virial, so + ``make_fx`` unfolds the inner ``autograd.grad(energy, edge_vec)`` + that ``edge_energy_deriv`` performs for force and virial, so the returned module can be handed to :func:`torch.export.export` directly. ``silu_backward`` is decomposed to primitive ops so Inductor never sees an opaque higher-order derivative — the same @@ -2162,11 +2145,12 @@ def lower_fn( aparam_: torch.Tensor | None, charge_spin_: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - # detach + requires_grad_ must live INSIDE the traced closure: - # LAMMPS feeds a plain fp64 non-leaf tensor, and the exported - # graph needs its own grad endpoint for the inner autograd.grad - # that fit_output_to_model_output performs. - ext_coord = ext_coord.detach().requires_grad_(True) + # Detach INSIDE the traced closure so the exported graph never + # captures the upstream LAMMPS tensor. The force-autograd + # endpoint is the per-edge ``edge_vec`` leaf created inside + # ``core_compute`` (edge-force scatter), so the coordinates carry + # no grad endpoint here. + ext_coord = ext_coord.detach() return model.forward_common_lower( ext_coord, ext_atype, @@ -2309,15 +2293,25 @@ def build_edge_list_from_nlist( extended_coord: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Build a compact edge list from DeePMD padded neighbor list. - Edge vectors are computed via ``index_select`` on ``extended_coord`` - so they remain differentiable w.r.t. the input coordinates. Two - masked dummy edges are always appended to avoid data-dependent - empty-edge branches that ``make_fx`` cannot trace and singular - edge-axis guards in Inductor's batched matmul lowering. 
+ Edge vectors are gathered from ``extended_coord`` with ghosts' + real (periodic-image) coordinates, so ``edge_vec`` is the true + minimum-image displacement and the edge-force-scatter virial is + PBC-correct without any extended-coordinate trick. Two masked dummy + edges are always appended to avoid data-dependent empty-edge branches + that ``make_fx`` cannot trace and singular edge-axis guards in + Inductor's batched matmul lowering. + + Two index spaces are returned. ``edge_index`` uses *local* node + indices ``[0, nf * nloc)`` (neighbours mapped to their local image) + and drives message passing in the descriptor. ``edge_index_ext`` + uses *extended* node indices ``[0, nf * nall)`` (neighbours kept as + their ghost) and drives the edge-force scatter in ``core_compute``, + which must emit per-ghost extended tensors for + ``communicate_extended_output`` / ``forward_lower``. Parameters ---------- @@ -2331,30 +2325,29 @@ def build_edge_list_from_nlist( Returns ------- edge_index - Edge indices with shape (2, E+2) where E is valid edge count. + Local (src, dst) indices with shape (2, E+2), values in + ``[0, nf * nloc)``; E is the valid edge count. edge_vec Edge vectors with shape (E+2, 3). edge_mask Boolean mask with shape (E+2,). The two trailing elements are ``False``. + edge_index_ext + Extended (src, dst) indices with shape (2, E+2), values in + ``[0, nf * nall)``, aligned 1:1 with ``edge_index`` / ``edge_vec``. """ nf, nloc, nsel = nlist.shape device = extended_coord.device nall = extended_coord.shape[1] - descriptor_model = self.atomic_model.descriptor - coord_for_diff = extended_coord.to(dtype=descriptor_model.compute_dtype) - - # === Step 1. Build per-edge geometry via index_select (differentiable) === - # NOTE: Edge vectors come from ``coord_flat.index_select(0, ...)`` - # rather than advanced indexing ``coord_flat[...]``. 
- # ``index_select`` has an explicit, well-defined backward that - # routes gradient cleanly back to the original extended - # coordinate tensor. Advanced indexing combined with make_fx - # symbolic shapes has previously produced silent gradient - # truncation in this project -- the second-derivative gradient - # over coordinates was effectively zero, with no error raised. - # ``torch.where(valid_flat, neighbor_flat, 0)`` sanitises padded - # ``-1`` entries before indexing so we never hit an out-of-range - # gather; the corresponding edges are filtered out below anyway. + + # === Step 1. Build per-edge geometry via gather === + # Edge vectors come from ``torch.gather`` rather than advanced indexing + # ``coord_flat[...]``: gather lowers to an explicit, symbolic-shape + # -friendly op under make_fx, while advanced indexing under symbolic + # shapes can silently truncate gradients. ``core_compute`` detaches the + # result into the ``edge_vec`` autograd leaf, so this gather is a pure + # forward op. ``torch.where(valid_flat, neighbor_flat, 0)`` sanitises + # padded ``-1`` entries before indexing so we never hit an + # out-of-range gather; the corresponding edges are filtered out below. neighbor_flat = nlist.reshape(-1) # ``dst_actual = arange(N*K) // K`` produces the same value # sequence as ``arange(N).repeat_interleave(K)`` but its length @@ -2386,12 +2379,12 @@ def build_edge_list_from_nlist( # visibly bounded by the atom axis. 
neighbor_safe_2d = neighbor_safe.to(dtype=torch.long).view(nf, nloc * nsel) nei_coord = torch.gather( - coord_for_diff, + extended_coord, 1, neighbor_safe_2d.unsqueeze(-1).expand(-1, -1, 3), ).reshape(-1, 3) dst_coord = torch.gather( - coord_for_diff[:, :nloc, :], + extended_coord[:, :nloc, :], 1, dst_local.view(nf, -1).unsqueeze(-1).expand(-1, -1, 3), ).reshape(-1, 3) @@ -2405,6 +2398,15 @@ def build_edge_list_from_nlist( src_local = torch.gather(mapping, 1, neighbor_safe_2d).reshape(-1) src_actual = f_idx * nloc + src_local.to(dtype=torch.long) + # Extended-index counterparts for the edge-force scatter. The + # neighbour keeps its ghost identity (``neighbor_safe`` indexes + # ``[0, nall)``) while the centre, always a local atom, occupies its + # own slot ``dst_local`` in the extended layout. Scattering the edge + # gradient onto these indices yields per-ghost extended force / + # virial, the exact contract ``communicate_extended_output`` reduces. + src_ext = f_idx * nall + neighbor_safe.to(dtype=torch.long) + dst_ext = f_idx * nall + dst_local + # Filter: valid nlist entry AND src in [0, nloc) AND non-zero distance. 
src_local_valid = (src_local >= 0) & (src_local < nloc) len_positive = edge_len2 > 1e-10 @@ -2436,13 +2438,20 @@ def build_edge_list_from_nlist( dst_sel = dst_actual.index_select(0, padded_idx) edge_vec_sel = diff.index_select(0, padded_idx) edge_index = torch.stack([src_sel, dst_sel], dim=0) + edge_index_ext = torch.stack( + [ + src_ext.index_select(0, padded_idx), + dst_ext.index_select(0, padded_idx), + ], + dim=0, + ) edge_mask = torch.cat( [ torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device), torch.zeros(dummy_count, dtype=torch.bool, device=device), ] ) - return edge_index, edge_vec_sel, edge_mask + return edge_index, edge_vec_sel, edge_mask, edge_index_ext # ========================================================================= # Input Canonicalization @@ -2935,6 +2944,7 @@ def __init__(self, type_map: list[str], mode: str = "zbl") -> None: if mode != "ZBL": raise ValueError(f"Unknown InterPotential mode: {mode}") self.mode = mode + self.ntypes_real = len(type_map) atomic_numbers = [] for elem in type_map: @@ -2976,89 +2986,21 @@ def _zbl_pair_energy( return _KE_EV_A * zi * zj / r * phi def forward( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - nloc: int, - real_type_count: int | None = None, - ) -> torch.Tensor: - """ - Compute per-atom pair energy from the standard neighbor list path. - - Parameters - ---------- - extended_coord - Coordinates in extended region with shape (nf, nall, 3) in Å. - extended_atype - Atom types in extended region with shape (nf, nall). - nlist - Neighbor list with shape (nf, nloc, nsel). - nloc : int - Number of local atoms. - real_type_count - Number of real atom types. Types with index greater than or equal to - this value are virtual spin types and are masked out of the - analytical potential. If omitted, all configured types are real. - - Returns - ------- - torch.Tensor - Per-atom pair energy with shape (nf, nloc, 1) in eV. 
- """ - if real_type_count is None: - real_type_count = int(self.atomic_numbers.numel()) - nf = extended_coord.shape[0] - coord64 = extended_coord.to(dtype=torch.float64) - atype_for_z = extended_atype.clamp(min=0) - atype_for_z = torch.where( - atype_for_z >= real_type_count, - atype_for_z - real_type_count, - atype_for_z, - ) - z_all = self.atomic_numbers[atype_for_z] # (nf, nall) - - # === Step 1. Gather neighbor coordinates and types === - nsel = nlist.shape[2] - nlist_clamp = nlist.clamp(min=0) # (nf, nloc, nsel) - nei_coord = torch.gather( - coord64, 1, nlist_clamp.unsqueeze(-1).expand(-1, -1, -1, 3).view(nf, -1, 3) - ).view(nf, nloc, nsel, 3) - atom_coord = coord64[:, :nloc].unsqueeze(2) # (nf, nloc, 1, 3) - diff = nei_coord - atom_coord # (nf, nloc, nsel, 3) - r = diff.norm(dim=-1).clamp(min=1e-10) # (nf, nloc, nsel) - - zi = z_all[:, :nloc].unsqueeze(2).expand_as(r) # (nf, nloc, nsel) - zj_idx = nlist_clamp - zj = torch.gather(z_all, 1, zj_idx.view(nf, -1)).view(nf, nloc, nsel) - - # === Step 2. Compute pair energies === - pair_e = self._zbl_pair_energy(r, zi, zj) # (nf, nloc, nsel) - - # Mask padding entries (nlist == -1) - valid = (nlist >= 0).to(dtype=pair_e.dtype) - center_is_real = (extended_atype[:, :nloc] < real_type_count).unsqueeze(2) - neighbor_atype = torch.gather(extended_atype, 1, nlist_clamp.view(nf, -1)).view( - nf, nloc, nsel - ) - neighbor_is_real = neighbor_atype < real_type_count - valid = valid * (center_is_real & neighbor_is_real).to(dtype=pair_e.dtype) - pair_e = pair_e * valid - - # Half contribution to avoid double-counting - atom_pair_energy = (pair_e * 0.5).sum(dim=-1, keepdim=True) # (nf, nloc, 1) - return atom_pair_energy.to(dtype=extended_coord.dtype) - - def forward_from_edges( self, edge_vec: torch.Tensor, edge_index: torch.Tensor, atype_flat: torch.Tensor, edge_mask: torch.Tensor, n_node: int, + real_type_count: int | None = None, ) -> torch.Tensor: """ - Compute per-atom pair energy from the compile-path edge list. 
+ Compute per-atom pair energy from the sparse edge list. + + The pair sum is evaluated from the descriptor's ``edge_vec`` so that + differentiating ``edge_vec`` (the edge-force-scatter leaf) routes the + analytical force / virial through the same single edge backward as the + learned energy. Parameters ---------- @@ -3072,22 +3014,42 @@ def forward_from_edges( Boolean mask with shape (E,). True means valid edge. n_node : int Number of flattened local nodes. + real_type_count + Number of real atom types. Types ``>= real_type_count`` are + virtual spin types: edges touching them are masked out. If + ``None``, all configured types are real. Returns ------- torch.Tensor Per-atom pair energy with shape (1, N, 1) in eV. """ + if real_type_count is None: + real_type_count = self.ntypes_real src = edge_index[0].to(dtype=torch.long) dst = edge_index[1].to(dtype=torch.long) + # Wrap virtual spin types (>= real_type_count) back onto their real + # parent type so the Z lookup never indexes out of range; their edges + # are masked out below regardless. + atype_for_z = atype_flat.clamp(min=0) + atype_for_z = torch.where( + atype_for_z >= real_type_count, + atype_for_z - real_type_count, + atype_for_z, + ) + r = edge_vec.to(dtype=torch.float64).norm(dim=-1).clamp(min=1e-10) # (E,) - z_all = self.atomic_numbers[atype_flat.clamp(min=0)] # (N,) + z_all = self.atomic_numbers[atype_for_z] # (N,) zi = z_all[src] # (E,) zj = z_all[dst] # (E,) pair_e = self._zbl_pair_energy(r, zi, zj) # (E,) - pair_e = pair_e * edge_mask.to(dtype=pair_e.dtype) + # Drop padded edges and any edge touching a virtual (spin) node. 
+ node_is_real = atype_flat < real_type_count # (N,) + edge_is_real = node_is_real[src] & node_is_real[dst] # (E,) + valid = edge_mask & edge_is_real + pair_e = pair_e * valid.to(dtype=pair_e.dtype) # Half contribution to each destination atom atom_energy = torch.zeros(n_node, dtype=pair_e.dtype, device=pair_e.device) diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 7cb7df000d..ca3515f650 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -206,6 +206,114 @@ def fit_output_to_model_output( return model_ret +def edge_energy_deriv( + energy_redu: torch.Tensor, + edge_vec: torch.Tensor, + src_ext: torch.Tensor, + dst_ext: torch.Tensor, + edge_mask: torch.Tensor, + nf: int, + nall: int, + create_graph: bool, + extended_coord_corr: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Assemble extended force, virial and atomic virial from edge gradients. + + The energy depends on coordinates only through the per-edge displacement + vectors ``edge_vec``. A single ``autograd.grad`` produces the per-edge + gradient ``g_e = dE / d(edge_vec_e)``; force, global virial and per-atom + virial are assembled from it with explicit scatter and outer-product ops. + + With edge ``e`` running from receiver ``dst(e)`` to sender ``src(e)`` and + ``edge_vec_e = r_{src(e)} - r_{dst(e)}``, the chain rule + ``d(edge_vec_e)/dr_k = (delta_{k,src} - delta_{k,dst}) I`` gives the + conservative force and the pairwise virial:: + + F_k = sum_{dst(e)=k} g_e - sum_{src(e)=k} g_e + W = - sum_e g_e (x) edge_vec_e + + ``src_ext`` and ``dst_ext`` index the flattened extended space + ``[0, nf * nall)``, so the scatter produces per-ghost extended tensors + consumed by ``communicate_extended_output`` and the lower interface. 
+ + ``edge_vec`` carries the coordinate precision (``GLOBAL_PT_FLOAT_PRECISION``), + so ``g`` and the assembled force / virial share that dtype -- the dtype the + ``communicate_extended_output`` scatter buffers and the reduced energy + expect. The reduced global virial is summed in + ``GLOBAL_PT_ENER_FLOAT_PRECISION``. + + Parameters + ---------- + energy_redu + Reduced per-frame energy with shape ``(nf, 1)``. + edge_vec + Per-edge displacement leaf with shape ``(E, 3)`` carrying ``requires_grad``. + src_ext, dst_ext + Sender / receiver indices into the flattened extended space, each with + shape ``(E,)``. + edge_mask + Boolean validity mask with shape ``(E,)``. + nf, nall + Frame count and extended-atom count. + create_graph + Keep the first-derivative graph alive so the force-loss second backward + can reach the parameters. + extended_coord_corr + Optional spin virtual-displacement correction with shape + ``(nf, nall, 3)``; adds ``force (x) coord_corr`` per extended atom. + + Returns + ------- + energy_derv_r + Extended force with shape ``(nf, nall, 1, 3)``. + energy_derv_c + Extended per-atom virial with shape ``(nf, nall, 1, 9)``, split + symmetrically between the two endpoints of each edge. + energy_derv_c_redu + Reduced global virial with shape ``(nf, 1, 9)``. + """ + (g,) = torch.autograd.grad( + [energy_redu], + [edge_vec], + grad_outputs=[torch.ones_like(energy_redu)], + create_graph=create_graph, + retain_graph=True, + ) + # Padded edges carry no energy contribution, so their gradient is zero; + # mask defensively before the scatter. + g = torch.where(edge_mask.unsqueeze(-1), g, torch.zeros_like(g)) + + n_ext = nf * nall + # Force: F_k = sum_{dst=k} g_e - sum_{src=k} g_e. 
+ force_flat = torch.zeros(n_ext, 3, dtype=g.dtype, device=g.device) + force_flat = force_flat.index_add(0, dst_ext, g) + force_flat = force_flat.index_add(0, src_ext, -g) + extended_force = force_flat.view(nf, nall, 3) + + # Per-edge virial outer product w_e[k, j] = -g_e^k * edge_vec_e^j, flattened + # to 9 with (force component k, coordinate component j) ordering. + w_edge = -torch.einsum("ek,ej->ekj", g, edge_vec).reshape(-1, 9) + # Atomic virial: split each per-edge tensor symmetrically between endpoints. + half_w = 0.5 * w_edge + av_flat = torch.zeros(n_ext, 9, dtype=g.dtype, device=g.device) + av_flat = av_flat.index_add(0, dst_ext, half_w) + av_flat = av_flat.index_add(0, src_ext, half_w) + extended_virial = av_flat.view(nf, nall, 9) + + if extended_coord_corr is not None: + # Spin: the virtual-atom displacement adds force (x) coord_corr per atom. + corr = ( + extended_force.unsqueeze(-1) + @ extended_coord_corr.unsqueeze(-2).to(extended_force.dtype) + ).reshape(nf, nall, 9) + extended_virial = extended_virial + corr + + energy_derv_r = extended_force.unsqueeze(-2) + energy_derv_c = extended_virial.unsqueeze(-2) + energy_derv_c_redu = energy_derv_c.to(env.GLOBAL_PT_ENER_FLOAT_PRECISION).sum(dim=1) + return energy_derv_r, energy_derv_c, energy_derv_c_redu + + def communicate_extended_output( model_ret: dict[str, torch.Tensor], model_output_def: ModelOutputDef, diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 0396b24ad2..afdd1bb72b 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -401,8 +401,9 @@ def test_compile_cache_slots_and_eval_shape_change(self) -> None: model_cmp = get_sezm_model(self._build_model_params(use_compile=True)) model_cmp.load_state_dict(model_dyn.state_dict()) - train_key = (True, False, False) - eval_key = (False, False, False) + # Compile cache key is (training, has_coord_corr). 
+ train_key = (True, False) + eval_key = (False, False) # === Step 2. Train-mode forward fills the training slot. === model_cmp.train() @@ -590,7 +591,7 @@ def test_fixed_edge_geometry_matches_standard_cache(self) -> None: wigner_calc=descriptor.wigner_calc, ) - edge_index, edge_vec, edge_mask = model.build_edge_list_from_nlist( + edge_index, edge_vec, edge_mask, _ = model.build_edge_list_from_nlist( extended_coord=extended_coord, nlist=nlist, mapping=mapping, @@ -913,7 +914,7 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: cache1 = wrapper_cmp.model["water_1"].compiled_core_compute_cache cache2 = wrapper_cmp.model["water_2"].compiled_core_compute_cache self.assertIsNot(cache1, cache2) - train_key = (True, False, False) + train_key = (True, False) self.assertIn(train_key, cache1) self.assertIn(train_key, cache2) c1 = cache1[train_key] @@ -941,8 +942,24 @@ class TestInterPotential(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE + def _pair_edges( + self, r: float, atype_pair: list[int] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Two directed edges (i->j and j->i) for one pair at distance r.""" + edge_vec = torch.tensor( + [[r, 0.0, 0.0], [-r, 0.0, 0.0]], + dtype=torch.float64, + device=self.device, + ) + edge_index = torch.tensor( + [[1, 0], [0, 1]], dtype=torch.long, device=self.device + ) + atype_flat = torch.tensor(atype_pair, dtype=torch.long, device=self.device) + edge_mask = torch.tensor([True, True], device=self.device) + return edge_vec, edge_index, atype_flat, edge_mask + def test_zbl_known_value_OO(self) -> None: - """Test ZBL energy for O-O pair at known distance against reference.""" + """ZBL energy for an O-O pair matches the analytic reference.""" pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) import math @@ -961,20 +978,11 @@ def test_zbl_known_value_OO(self) -> None: ) expected = ke * z_o * z_o / r * phi - extended_coord = torch.tensor( - [[[0.0, 0.0, 0.0], [1.0, 
0.0, 0.0]]], - dtype=torch.float64, - device=self.device, - ) - extended_atype = torch.tensor([[0, 0]], dtype=torch.int64, device=self.device) - nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) - - pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) - total_e = pair_e.sum().item() + total_e = pot(*self._pair_edges(r, [0, 0]), n_node=2).sum().item() self.assertAlmostEqual(total_e, expected, places=5) def test_zbl_known_value_OH(self) -> None: - """Test ZBL energy for O-H pair at known distance.""" + """ZBL energy for an O-H pair matches the analytic reference.""" pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) import math @@ -992,71 +1000,234 @@ def test_zbl_known_value_OH(self) -> None: ) expected = ke * z_o * z_h / r * phi - extended_coord = torch.tensor( - [[[0.0, 0.0, 0.0], [0.8, 0.0, 0.0]]], - dtype=torch.float64, - device=self.device, - ) - extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) - nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) - - pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) - total_e = pair_e.sum().item() + total_e = pot(*self._pair_edges(r, [0, 1]), n_node=2).sum().item() self.assertAlmostEqual(total_e, expected, places=5) def test_zbl_gradient_exists(self) -> None: - """Test that ZBL potential produces valid gradients for force computation.""" + """ZBL produces finite gradients w.r.t. 
the edge vectors.""" pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + edge_vec, edge_index, atype_flat, edge_mask = self._pair_edges(1.0, [0, 1]) + edge_vec = edge_vec.detach().requires_grad_(True) + + pot(edge_vec, edge_index, atype_flat, edge_mask, n_node=2).sum().backward() + self.assertIsNotNone(edge_vec.grad) + self.assertTrue(torch.isfinite(edge_vec.grad).all()) - extended_coord = torch.tensor( - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], + def test_virtual_spin_types_masked(self) -> None: + """Edges touching a virtual spin type (>= real_type_count) contribute 0.""" + pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + # Node 2 is a virtual spin atom (type 2 >= real_type_count=2). + edge_vec = torch.tensor( + [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]], dtype=torch.float64, device=self.device, - requires_grad=True, ) - extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) - nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) + # Edges: (0<->1) real-real, (0<->2) touch virtual node 2. + edge_index = torch.tensor( + [[1, 0, 2, 0], [0, 1, 0, 2]], dtype=torch.long, device=self.device + ) + atype_flat = torch.tensor([0, 1, 2], dtype=torch.long, device=self.device) + edge_mask = torch.tensor([True, True, True, True], device=self.device) - pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) - pair_e.sum().backward() - self.assertIsNotNone(extended_coord.grad) - self.assertTrue(torch.isfinite(extended_coord.grad).all()) + with_virtual = pot( + edge_vec, edge_index, atype_flat, edge_mask, n_node=3, real_type_count=2 + ) + # Only the real-real pair survives. 
+ real_only = pot( + edge_vec[:2], + edge_index[:, :2], + atype_flat, + edge_mask[:2], + n_node=3, + real_type_count=2, + ) + torch.testing.assert_close(with_virtual, real_only) def test_unknown_element_raises(self) -> None: """Test that unknown element raises ValueError.""" with self.assertRaises(ValueError): InterPotential(type_map=["O", "Xx"]) - def test_forward_from_edges(self) -> None: - """Test the compile-path edge-based ZBL computation.""" - pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) - edge_vec = torch.tensor( - [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], +class TestSeZMEdgeForceScatter(unittest.TestCase): + """Validate the edge-force-scatter force / virial assembly. + + Force, global virial and per-atom virial all come from a single + ``autograd.grad`` truncated at the per-edge displacement vectors + (``edge_energy_deriv``), then scattered back onto atoms. These eager, + float64 finite-difference checks pin the conservative-force guarantee + ``F = -dE/dx`` and the PBC-correct virial ``W = -dE/deps``, and confirm + the half-split per-atom virial sums back to the global virial. The ZBL + cases additionally drive ``InterPotential`` (edge form) through the + same single backward. 
+ """ + + def setUp(self) -> None: + self.device = env.DEVICE + + def _build_model(self, *, bridging_method: str = "none") -> SeZMModel: + """Build a tiny float64 SeZM model with randomized parameters.""" + params = { + "type": "SeZM", + "type_map": ["O", "H"], + "descriptor": { + "type": "SeZM", + "sel": [12, 12], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float64", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + "use_compile": False, + "bridging_method": bridging_method, + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + } + model = get_sezm_model(params) + torch.manual_seed(1234) + with torch.no_grad(): + for p in model.parameters(): + p.copy_(torch.randn_like(p) * 0.1) + model.eval() + return model + + def _frame(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Small periodic frame with dense neighbours inside ``rcut``.""" + coord = torch.tensor( + [ + [ + [0.10, 0.05, 0.00], + [1.05, 0.30, 0.10], + [0.20, 1.40, 0.35], + [1.60, 1.15, 0.20], + [2.20, 0.10, 1.05], + ] + ], dtype=torch.float64, device=self.device, ) - edge_index = torch.tensor( - [[1, 0], [0, 1]], dtype=torch.long, device=self.device - ) - atype_flat = torch.tensor([0, 1], dtype=torch.long, device=self.device) - edge_mask = torch.tensor([True, True], device=self.device) - - result = pot.forward_from_edges(edge_vec, edge_index, atype_flat, edge_mask, 2) - self.assertEqual(result.shape, (1, 2, 1)) - self.assertTrue(torch.isfinite(result).all()) - - extended_coord = torch.tensor( - 
[[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], + atype = torch.tensor([[0, 1, 0, 1, 0]], dtype=torch.int64, device=self.device) + box = torch.tensor( + [[6.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 6.0]], dtype=torch.float64, device=self.device, ) - extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) - nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) - pair_e_nlist = pot(extended_coord, extended_atype, nlist, nloc=2) + return coord, atype, box + + def _energy( + self, + model: SeZMModel, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + return model(coord, atype, box=box)["energy"].squeeze() + + def _check_force_fd(self, bridging_method: str, *, periodic: bool = True) -> None: + model = self._build_model(bridging_method=bridging_method) + coord, atype, box = self._frame() + # box=None exercises the non-periodic (open-boundary / cluster) path: + # the edge-force scatter is PBC-agnostic because it differentiates the + # real per-edge displacement, so the same assembly must hold. 
+ if not periodic: + box = None + force = model(coord, atype, box=box)["force"] + + eps = 1.0e-5 + nloc = coord.shape[1] + fd_force = torch.zeros_like(force) + for a in range(nloc): + for d in range(3): + cp = coord.clone() + cp[0, a, d] += eps + cm = coord.clone() + cm[0, a, d] -= eps + e_plus = self._energy(model, cp, atype, box) + e_minus = self._energy(model, cm, atype, box) + fd_force[0, a, d] = -(e_plus - e_minus) / (2 * eps) + boundary = "periodic" if periodic else "non-periodic" torch.testing.assert_close( - result.sum(), pair_e_nlist.sum().to(result.dtype), atol=1e-8, rtol=1e-8 - ) + force, + fd_force, + atol=1.0e-6, + rtol=1.0e-4, + msg=f"edge-scatter force != finite difference " + f"({bridging_method}, {boundary})", + ) + + def test_force_matches_finite_difference(self) -> None: + """F = -dE/dx for the pure descriptor path.""" + self._check_force_fd("none") + + def test_force_matches_finite_difference_zbl(self) -> None: + """F = -dE/dx with ZBL bridging routed through the edge ZBL form.""" + self._check_force_fd("ZBL") + + def test_force_matches_finite_difference_nonperiodic(self) -> None: + """F = -dE/dx for a non-periodic (box=None) cluster.""" + self._check_force_fd("none", periodic=False) + + def test_virial_matches_strain_finite_difference(self) -> None: + """W = -dE/deps under a random symmetric strain (PBC-correct virial).""" + model = self._build_model(bridging_method="none") + coord, atype, box = self._frame() + virial = model(coord, atype, box=box)["virial"].view(3, 3) + + torch.manual_seed(0) + s = torch.randn(3, 3, dtype=torch.float64, device=self.device) + strain = 1.0e-4 * (s + s.transpose(0, 1)) + eye = torch.eye(3, dtype=torch.float64, device=self.device) + + def deformed_energy(sign: float) -> torch.Tensor: + m = (eye + sign * strain).transpose(0, 1) + coord_d = coord @ m + box_d = (box.view(1, 3, 3) @ m).reshape(1, 9) + return self._energy(model, coord_d, atype, box_d) + + e_plus = deformed_energy(1.0) + e_minus = deformed_energy(-1.0) 
+ # dE/dt|_0 = -sum(strain * virial), central difference over t = +/-1. + lhs = (strain * virial).sum() + rhs = -(e_plus - e_minus) / 2.0 + torch.testing.assert_close(lhs, rhs, atol=1.0e-8, rtol=1.0e-4) + + def test_atom_virial_sums_to_global_virial(self) -> None: + """Half-split per-atom virial reduces to the global virial.""" + for bridging_method in ("none", "ZBL"): + model = self._build_model(bridging_method=bridging_method) + coord, atype, box = self._frame() + out = model(coord, atype, box=box, do_atomic_virial=True) + torch.testing.assert_close( + out["atom_virial"].sum(dim=1), + out["virial"], + atol=1.0e-10, + rtol=1.0e-6, + msg=f"atom_virial sum != global virial ({bridging_method})", + ) class TestSeZMModelBridging(unittest.TestCase): diff --git a/source/tests/pt/model/test_sezm_spin_model.py b/source/tests/pt/model/test_sezm_spin_model.py index aeaec3f8df..b2a7b1be73 100644 --- a/source/tests/pt/model/test_sezm_spin_model.py +++ b/source/tests/pt/model/test_sezm_spin_model.py @@ -352,33 +352,28 @@ def test_bridging_masks_virtual_pairs(self) -> None: ) self.assertIsNotNone(model.inter_potential) - coord = torch.tensor( - [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], + # Nodes 0 (type 0) and 1 (type 1) are real; node 2 (type 2) is a virtual + # spin atom (>= real_type_count=2). Edges touching node 2 must be masked.
+ edge_vec = torch.tensor( + [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]], dtype=torch.float64, device=self.device, ) - atype_with_virtual = torch.tensor( - [[0, 1, 2]], dtype=torch.long, device=self.device - ) - nlist_real_and_virtual = torch.tensor( - [[[1, 2], [0, 2], [0, 1]]], dtype=torch.long, device=self.device - ) - nlist_real_only = torch.tensor( - [[[1, -1], [0, -1], [-1, -1]]], dtype=torch.long, device=self.device + edge_index = torch.tensor( + [[1, 0, 2, 0], [0, 1, 0, 2]], dtype=torch.long, device=self.device ) + atype_flat = torch.tensor([0, 1, 2], dtype=torch.long, device=self.device) + edge_mask = torch.tensor([True, True, True, True], device=self.device) energy_with_virtual = model.inter_potential( - coord, - atype_with_virtual, - nlist_real_and_virtual, - nloc=3, - real_type_count=2, + edge_vec, edge_index, atype_flat, edge_mask, n_node=3, real_type_count=2 ) energy_real_only = model.inter_potential( - coord, - atype_with_virtual, - nlist_real_only, - nloc=3, + edge_vec[:2], + edge_index[:, :2], + atype_flat, + edge_mask[:2], + n_node=3, real_type_count=2, ) From a5cea0fc34589a7641b7aa808d8f85fd4459d891 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 12 Jun 2026 17:49:02 +0800 Subject: [PATCH 3/3] fix edge force --- source/tests/pt/model/test_sezm_spin_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/model/test_sezm_spin_model.py b/source/tests/pt/model/test_sezm_spin_model.py index b2a7b1be73..207743ee8e 100644 --- a/source/tests/pt/model/test_sezm_spin_model.py +++ b/source/tests/pt/model/test_sezm_spin_model.py @@ -394,7 +394,7 @@ def test_compile_matches_eager(self) -> None: out_eager = eager(self.coord, self.atype, spin=self.spin, box=self.box) out_compiled = compiled(self.coord, self.atype, spin=self.spin, box=self.box) - self.assertIn((False, False, True), compiled.compiled_core_compute_cache) + self.assertIn((False, True), compiled.compiled_core_compute_cache) 
_assert_close_with_strict_warning( out_compiled["energy"], out_eager["energy"],