From 3bdc8f912e17bae5894c96246c4ea471df6bd999 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 16:56:31 +0800 Subject: [PATCH 01/12] feat(pt_expt): DPA4 descriptor wrapper with converter registrations --- .../dpmodel/descriptor/dpa4_nn/embedding.py | 8 +- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 11 +- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/dpa4.py | 66 ++++++++ source/tests/pt_expt/descriptor/test_dpa4.py | 147 ++++++++++++++++++ 5 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/dpa4.py create mode 100644 source/tests/pt_expt/descriptor/test_dpa4.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py index e20e69fc87..152b8b5a55 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py @@ -548,7 +548,13 @@ def __init__( seed=child_seed(seed, 3), trainable=self.trainable, ) - self.output_proj.w = np.zeros_like(self.output_proj.w) + # Use an explicit shape/dtype instead of np.zeros_like(self.output_proj.w): + # in pt_expt the attribute is a requires-grad torch Parameter, on which + # numpy __array__ conversion raises. + self.output_proj.w = np.zeros( + (self.embed_dim * self.axis_dim, 2 * self.channels), + dtype=PRECISION_DICT[self.precision.lower()], + ) def call( self, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index eca0209e93..12b0ff5371 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -426,15 +426,20 @@ def _load_variables(self, variables: dict[str, Any]) -> None: self.bias0 = np.asarray(variables["bias0"], dtype=prec).reshape( self.bias0.shape ) + # Rebuild the list and assign the whole attribute (rather than + # item-assignment) so that pt_expt, which converts the list to a + # torch ParameterList, can re-convert the new value cleanly. + new_weight_m = [] for m_idx in range(len(self.weight_m)): key = f"weight_m.{m_idx}" value = np.asarray(variables[key], dtype=prec) - if value.shape != self.weight_m[m_idx].shape: + if value.shape != tuple(self.weight_m[m_idx].shape): raise ValueError( f"{key} shape {value.shape} does not match the expected " - f"shape {self.weight_m[m_idx].shape}" + f"shape {tuple(self.weight_m[m_idx].shape)}" ) - self.weight_m[m_idx] = value + new_weight_m.append(value) + self.weight_m = new_weight_m def serialize(self) -> dict[str, Any]: """Serialize the SO2Linear to a dict (pt-compatible format).""" diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 8253ed6338..163a6788d8 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -17,6 +17,9 @@ from .dpa3 import ( DescrptDPA3, ) +from .dpa4 import ( + DescrptDPA4, +) from .hybrid import ( DescrptHybrid, ) @@ -41,6 +44,7 @@ "DescrptDPA1", "DescrptDPA2", "DescrptDPA3", + "DescrptDPA4", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/pt_expt/descriptor/dpa4.py b/deepmd/pt_expt/descriptor/dpa4.py new file mode 100644 index 0000000000..9aa2473560 --- /dev/null +++ b/deepmd/pt_expt/descriptor/dpa4.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa4 import DescrptDPA4 as DescrptDPA4DP +from deepmd.dpmodel.descriptor.dpa4_nn.activation import SwiGLU as SwiGLUDP +from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator as WignerDCalculatorDP, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.utils.update_sel import ( + UpdateSel, +) + + +@torch_module +class WignerDCalculator(WignerDCalculatorDP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) + + +# WignerDCalculator.deserialize raises NotImplementedError by design (its +# tables are derived constants); rebuild from the stored constructor args. +register_dpmodel_mapping( + WignerDCalculatorDP, + lambda v: WignerDCalculator(v.lmax, eps=v.eps, precision=v.precision), +) + + +@torch_module +class SwiGLU(SwiGLUDP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) + + +# SwiGLU is parameter-free (no serialize); rebuild fresh. +register_dpmodel_mapping(SwiGLUDP, lambda v: SwiGLU()) + + +@BaseDescriptor.register("SeZM") +@BaseDescriptor.register("sezm") +@BaseDescriptor.register("DPA4") +@BaseDescriptor.register("dpa4") +@torch_module +class DescrptDPA4(DescrptDPA4DP): + _update_sel_cls = UpdateSel + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) + + def share_params( + self, + base_class: "DescrptDPA4", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + # Multi-task parameter sharing for DPA4 is out of scope for this PR. + raise NotImplementedError("share_params is not yet implemented for DescrptDPA4") diff --git a/source/tests/pt_expt/descriptor/test_dpa4.py b/source/tests/pt_expt/descriptor/test_dpa4.py new file mode 100644 index 0000000000..20efe20b28 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_dpa4.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.dpa4 import DescrptDPA4 as DPDescrptDPA4 +from deepmd.pt_expt.descriptor.dpa4 import ( + DescrptDPA4, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...common.test_mixins import ( + TestCaseSingleFrameWithNlist, +) + + +def make_descriptor(nt, sel, rcut, **overrides) -> DescrptDPA4: + kwargs = { + "ntypes": nt, + "sel": sel, + "rcut": rcut, + "channels": 16, + "n_radial": 8, + "lmax": 2, + "mmax": 1, + "n_blocks": 2, + "grid_branch": [1, 1, 1], + "s2_activation": [False, True], + "random_gamma": False, + "precision": "float64", + "seed": 7, + } + kwargs.update(overrides) + return DescrptDPA4(**kwargs) + + +class TestDescrptDPA4(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + @pytest.mark.parametrize("use_env_seed", [True, False]) # env seed feature + @pytest.mark.parametrize("use_mapping", [True, False]) # pass mapping vs None + def test_consistency(self, use_env_seed, use_mapping) -> None: + dtype = PRECISION_DICT["float64"] + err_msg = f"use_env_seed={use_env_seed} use_mapping={use_mapping}" + dd0 = make_descriptor( + self.nt, + self.sel_mix, + self.rcut, + use_env_seed=use_env_seed, + ).to(self.device) + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + mapping = ( + torch.tensor(self.mapping, dtype=int, device=self.device) + if use_mapping + else None + ) + rd0 = dd0(coord_ext, atype_ext, nlist, mapping)[0] + # serialization round-trip within pt_expt + dd1 = DescrptDPA4.deserialize(dd0.serialize()) + rd1 = dd1(coord_ext, atype_ext, nlist, mapping)[0] + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-14, + err_msg=err_msg, + ) + # dpmodel (numpy) impl + dd2 = DPDescrptDPA4.deserialize(dd0.serialize()) + rd2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + mapping=self.mapping if use_mapping else None, + )[0] + # CPU: strict same-math parity; CUDA: ULP / nondeterministic reduction slack + if self.device == "cpu" or str(self.device) == "cpu": + rtol, atol = 1e-12, 1e-14 + else: + rtol, atol = 1e-10, 1e-12 + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_exportable(self, prec) -> None: + dtype = PRECISION_DICT[prec] + dd0 = make_descriptor(self.nt, self.sel_mix, self.rcut, precision=prec).to( + self.device + ) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + dtype = PRECISION_DICT[prec] + dd0 = make_descriptor(self.nt, self.sel_mix, self.rcut, precision=prec).to( + self.device + ) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) From 6acc0ef3f8d81b3f8461099dbf610ba3415ab7b6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 18:45:30 +0800 Subject: [PATCH 02/12] feat(pt_expt): dpa4_ener fitting wrapper --- deepmd/dpmodel/fitting/dpa4_ener.py | 15 ++- deepmd/pt_expt/fitting/__init__.py | 4 + deepmd/pt_expt/fitting/dpa4_ener.py | 70 +++++++++++ .../tests/pt_expt/fitting/test_dpa4_ener.py | 109 ++++++++++++++++++ 4 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 deepmd/pt_expt/fitting/dpa4_ener.py create mode 100644 source/tests/pt_expt/fitting/test_dpa4_ener.py diff --git a/deepmd/dpmodel/fitting/dpa4_ener.py b/deepmd/dpmodel/fitting/dpa4_ener.py index b8f9507014..f8f1d8c8da 100644 --- a/deepmd/dpmodel/fitting/dpa4_ener.py +++ b/deepmd/dpmodel/fitting/dpa4_ener.py @@ -15,6 +15,9 @@ DEFAULT_PRECISION, NativeOP, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.array_api import ( Array, ) @@ -156,11 +159,15 @@ def serialize(self) -> dict[str, Any]: """Serialize the network to a dict (pt state_dict key contract).""" variables: dict[str, Any] = {} for layer_idx, layer in enumerate(self.hidden_layers): - variables[f"hidden_layers.{layer_idx}.linear.matrix"] = layer.w - variables[f"hidden_layers.{layer_idx}.linear.bias"] = layer.b - variables["output_layer.matrix"] = self.output_layer.w + variables[f"hidden_layers.{layer_idx}.linear.matrix"] = to_numpy_array( + layer.w + ) + variables[f"hidden_layers.{layer_idx}.linear.bias"] = to_numpy_array( + layer.b + ) + variables["output_layer.matrix"] = to_numpy_array(self.output_layer.w) if self.bias_out: - variables["output_layer.bias"] = self.output_layer.b + variables["output_layer.bias"] = to_numpy_array(self.output_layer.b) return { "@class": "GLUFittingNet", "@version": 1, diff --git a/deepmd/pt_expt/fitting/__init__.py b/deepmd/pt_expt/fitting/__init__.py index 3b69392cfd..8217f64bd2 100644 --- a/deepmd/pt_expt/fitting/__init__.py +++ b/deepmd/pt_expt/fitting/__init__.py @@ -8,6 +8,9 @@ from .dos_fitting import ( DOSFittingNet, ) +from .dpa4_ener import ( + SeZMEnergyFittingNet, +) from .ener_fitting import ( EnergyFittingNet, ) @@ -29,4 +32,5 @@ "InvarFitting", "PolarFitting", "PropertyFittingNet", + "SeZMEnergyFittingNet", ] diff --git a/deepmd/pt_expt/fitting/dpa4_ener.py b/deepmd/pt_expt/fitting/dpa4_ener.py new file mode 100644 index 0000000000..49a65b883f --- /dev/null +++ b/deepmd/pt_expt/fitting/dpa4_ener.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, +) + +import torch + +from deepmd.dpmodel.fitting.dpa4_ener import GLUFittingNet as GLUFittingNetDP +from deepmd.dpmodel.fitting.dpa4_ener import ( + SeZMEnergyFittingNet as SeZMEnergyFittingNetDP, +) +from deepmd.dpmodel.fitting.dpa4_ener import ( + SeZMNetworkCollection as SeZMNetworkCollectionDP, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, + torch_module, +) + +from .base_fitting import ( + BaseFitting, +) + + +@torch_module +class GLUFittingNet(GLUFittingNetDP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) + + +register_dpmodel_mapping( + GLUFittingNetDP, + lambda v: GLUFittingNet.deserialize(v.serialize()), +) + + +@torch_module +class SeZMNetworkCollection(SeZMNetworkCollectionDP): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "sezm_fitting_network": GLUFittingNet, + } + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._module_networks = torch.nn.ModuleDict() + super().__init__(*args, **kwargs) + + def __setitem__(self, key: int | tuple | str, value: Any) -> None: + super().__setitem__(key, value) + idx = self._convert_key(key) + net = self._networks[idx] + key_str = str(idx) + if isinstance(net, torch.nn.Module): + self._module_networks[key_str] = net + elif key_str in self._module_networks: + del self._module_networks[key_str] + + +register_dpmodel_mapping( + SeZMNetworkCollectionDP, + lambda v: SeZMNetworkCollection.deserialize(v.serialize()), +) + + +@BaseFitting.register("dpa4_ener") +@BaseFitting.register("sezm_ener") +@torch_module +class SeZMEnergyFittingNet(SeZMEnergyFittingNetDP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/source/tests/pt_expt/fitting/test_dpa4_ener.py b/source/tests/pt_expt/fitting/test_dpa4_ener.py new file mode 100644 index 0000000000..0c7469611e --- /dev/null +++ b/source/tests/pt_expt/fitting/test_dpa4_ener.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.fitting.dpa4_ener import ( + SeZMEnergyFittingNet as SeZMEnergyFittingNetDP, +) +from deepmd.pt_expt.fitting.dpa4_ener import ( + SeZMEnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...common.test_mixins import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + +DIM_DESCRPT = 12 + + +class TestSeZMEnergyFittingNet(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + rng = np.random.default_rng(GLOBAL_SEED) + self.descriptor = rng.normal(size=(self.nf, self.nloc, DIM_DESCRPT)) + self.atype = self.atype_ext[:, : self.nloc] + + @pytest.mark.parametrize("neuron", [[0], [16, 16]]) # auto-width / hidden layers + @pytest.mark.parametrize("mixed_types", [True, False]) # type-mixed branches + def test_self_consistency_and_dpmodel(self, neuron, mixed_types) -> None: + ft0 = SeZMEnergyFittingNet( + self.nt, + DIM_DESCRPT, + neuron=neuron, + mixed_types=mixed_types, + precision="float64", + seed=GLOBAL_SEED, + ).to(self.device) + ft1 = SeZMEnergyFittingNet.deserialize(ft0.serialize()).to(self.device) + ft_dp = SeZMEnergyFittingNetDP.deserialize(ft0.serialize()) + + descriptor_t = torch.from_numpy(self.descriptor).to(self.device) + atype_t = torch.from_numpy(self.atype).to(self.device) + ret0 = ft0(descriptor_t, atype_t)["energy"].detach().cpu().numpy() + ret1 = ft1(descriptor_t, atype_t)["energy"].detach().cpu().numpy() + ret_dp = ft_dp.call(self.descriptor, self.atype)["energy"] + np.testing.assert_allclose(ret0, ret1, rtol=1e-12, atol=1e-14) + np.testing.assert_allclose(ret0, ret_dp, rtol=1e-12, atol=1e-14) + + @pytest.mark.parametrize("neuron", [[0], [16, 16]]) # auto-width / hidden layers + def test_trainable_parameters(self, neuron) -> None: + ft = SeZMEnergyFittingNet( + self.nt, + DIM_DESCRPT, + neuron=neuron, + precision="float64", + seed=GLOBAL_SEED, + ).to(self.device) + params = list(ft.parameters()) + assert len(params) > 0 + names = [name for name, _ in ft.named_parameters()] + assert any("hidden_layers" in name for name in names) + assert any("output_layer" in name for name in names) + + def test_serialize_type(self) -> None: + ft = SeZMEnergyFittingNet( + self.nt, DIM_DESCRPT, neuron=[16], precision="float64", seed=GLOBAL_SEED + ).to(self.device) + assert ft.serialize()["type"] == "sezm_ener" + + def test_make_fx(self) -> None: + ft = ( + SeZMEnergyFittingNet( + self.nt, + DIM_DESCRPT, + neuron=[16, 16], + precision="float64", + seed=GLOBAL_SEED, + ) + .to(self.device) + .eval() + ) + descriptor_t = torch.from_numpy(self.descriptor).to(self.device) + atype_t = torch.from_numpy(self.atype).to(self.device) + + def fn(descriptor, atype): + descriptor = descriptor.detach().requires_grad_(True) + ret = ft(descriptor, atype)["energy"] + grad = torch.autograd.grad(ret.sum(), descriptor, create_graph=False)[0] + return ret, grad + + traced = make_fx(fn)(descriptor_t, atype_t) + ret_t, grad_t = traced(descriptor_t, atype_t) + ret_e, grad_e = fn(descriptor_t, atype_t) + np.testing.assert_allclose( + ret_t.detach().cpu().numpy(), ret_e.detach().cpu().numpy() + ) + np.testing.assert_allclose( + grad_t.detach().cpu().numpy(), grad_e.detach().cpu().numpy() + ) From 50e6a7dfe8635c8caf4080872a45ba12b63027ab Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 19:04:49 +0800 Subject: [PATCH 03/12] feat(pt_expt): dpa4 model-type assembly --- deepmd/dpmodel/descriptor/dpa4.py | 5 + deepmd/pt_expt/model/get_model.py | 60 +++++++ .../pt_expt/model/test_get_model_dpa4.py | 157 ++++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 source/tests/pt_expt/model/test_get_model_dpa4.py diff --git a/deepmd/dpmodel/descriptor/dpa4.py b/deepmd/dpmodel/descriptor/dpa4.py index ad8c30a1d4..628b5dfb73 100644 --- a/deepmd/dpmodel/descriptor/dpa4.py +++ b/deepmd/dpmodel/descriptor/dpa4.py @@ -762,6 +762,7 @@ def call( mapping: Array | None = None, fparam: Array | None = None, comm_dict: dict | None = None, + charge_spin: Array | None = None, ) -> tuple[Array, Any, Any, Any, Any]: """Compute the DPA4 descriptor. @@ -780,6 +781,10 @@ def call( Frame parameters; not used by DPA4 (interface compatibility). comm_dict MPI communication metadata; not used (interface compatibility). + charge_spin + Charge/spin embedding input; must be None since + ``add_chg_spin_ebd=True`` is rejected at construction + (interface compatibility with ``DPAtomicModel``). Returns ------- diff --git a/deepmd/pt_expt/model/get_model.py b/deepmd/pt_expt/model/get_model.py index 9ca32ef641..d2e70c7d35 100644 --- a/deepmd/pt_expt/model/get_model.py +++ b/deepmd/pt_expt/model/get_model.py @@ -110,6 +110,64 @@ def get_standard_model(data: dict) -> EnergyModel: return model +def get_sezm_model(data: dict) -> EnergyModel: + """Build a pt_expt energy model from a DPA4/SeZM model config. + + Mirrors :func:`deepmd.pt.model.model.get_sezm_model` so that dpa4/sezm + training configs are interchangeable between the pt and pt_expt backends. + The pt-only SeZM extensions (bridging, LoRA, compile, spin) are not + supported here and raise ``NotImplementedError``. + """ + data = copy.deepcopy(data) + if "spin" in data: + raise NotImplementedError( + "Spin DPA4/SeZM models are not supported in the pt_expt backend." + ) + if str(data.get("bridging_method", "none")).lower() != "none": + raise NotImplementedError( + "`bridging_method` is not supported for DPA4/SeZM in the pt_expt backend." + ) + if data.get("lora") is not None: + raise NotImplementedError( + "`lora` is not supported for DPA4/SeZM in the pt_expt backend." + ) + if data.get("use_compile"): + raise NotImplementedError( + "`use_compile` is not supported for DPA4/SeZM in the pt_expt backend." + ) + data.pop("type", None) + data.setdefault("descriptor", {}) + data.setdefault("fitting_net", {}) + data["descriptor"].setdefault("type", "dpa4") + data["fitting_net"].setdefault("type", "dpa4_ener") + + # keep descriptor.exclude_types and model pair_exclude_types consistent + descriptor_exclude_types = [ + list(pair) for pair in (data["descriptor"].get("exclude_types") or []) + ] + if "pair_exclude_types" in data: + pair_exclude_types = [list(pair) for pair in (data["pair_exclude_types"] or [])] + if descriptor_exclude_types and descriptor_exclude_types != pair_exclude_types: + raise ValueError( + "SeZM `pair_exclude_types` and `descriptor.exclude_types` must match " + "when both are provided." + ) + else: + pair_exclude_types = descriptor_exclude_types + data["pair_exclude_types"] = pair_exclude_types + data["descriptor"]["exclude_types"] = copy.deepcopy(pair_exclude_types) + + ntypes = len(data["type_map"]) + descriptor, fitting, _ = _get_standard_model_components(data, ntypes) + return EnergyModel( + descriptor=descriptor, + fitting=fitting, + type_map=data["type_map"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=pair_exclude_types, + ) + + def get_linear_model(model_params: dict) -> BaseModel: """Get a linear energy model from a config dictionary. @@ -213,5 +271,7 @@ def get_model(data: dict) -> BaseModel: return get_standard_model(data) elif model_type == "linear_ener": return get_linear_model(data) + elif model_type in ("dpa4", "DPA4", "sezm", "SeZM"): + return get_sezm_model(data) else: return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/source/tests/pt_expt/model/test_get_model_dpa4.py b/source/tests/pt_expt/model/test_get_model_dpa4.py new file mode 100644 index 0000000000..0641be383c --- /dev/null +++ b/source/tests/pt_expt/model/test_get_model_dpa4.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for the DPA4/SeZM model-type dispatch in pt_expt ``get_model``.""" + +import copy +import unittest + +import torch + +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.pt_expt.model.ener_model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) + + +def _make_raw_model_config(**model_overrides) -> dict: + """Minimal (un-normalized) dpa4 model config; small dims for speed.""" + model = { + "type": "dpa4", + "type_map": ["O", "H"], + "descriptor": { + "sel": 20, + "rcut": 4.0, + "channels": 8, + "n_radial": 4, + "lmax": 1, + "mmax": 1, + "n_blocks": 1, + "precision": "float64", + "seed": 1, + }, + "fitting_net": { + "precision": "float64", + "seed": 1, + }, + } + model.update(model_overrides) + return model + + +def _normalize_model(model: dict) -> dict: + config = { + "model": model, + "training": {"training_data": {"systems": ["dummy"]}, "numb_steps": 1}, + "loss": {"type": "ener"}, + "learning_rate": {"type": "exp", "start_lr": 1e-3}, + } + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config["model"] + + +class TestGetModelDPA4(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + + def test_get_model_normalized_config(self) -> None: + """Normalized argcheck config (type key present) builds an EnergyModel.""" + model_params = _normalize_model(_make_raw_model_config()) + self.assertEqual(model_params["type"], "dpa4") + model = get_model(model_params).to(self.device) + self.assertIsInstance(model, torch.nn.Module) + self.assertIsInstance(model, EnergyModel) + self.assertEqual(model.get_dim_fparam(), 0) + self.assertEqual(model.get_type_map(), ["O", "H"]) + nparams = sum(p.numel() for p in model.parameters()) + self.assertGreater(nparams, 0) + # forward smoke + generator = torch.Generator(device=self.device).manual_seed(1) + cell = 5.0 * torch.eye(3, dtype=torch.float64, device=self.device) + coord = ( + torch.rand( + [1, 5, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + @ cell + ).requires_grad_(True) + atype = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device) + ret = model(coord, atype, cell.reshape(1, 9)) + self.assertEqual(ret["energy"].shape, (1, 1)) + self.assertEqual(ret["force"].shape, (1, 5, 3)) + + def test_get_model_type_aliases(self) -> None: + """All model-type aliases route to the SeZM path.""" + for alias in ("dpa4", "DPA4", "sezm", "SeZM"): + model_params = _make_raw_model_config(type=alias) + model = get_model(model_params) + self.assertIsInstance(model, EnergyModel, msg=f"alias={alias}") + + def test_descriptor_fitting_type_defaults(self) -> None: + """Descriptor/fitting type keys default to dpa4/dpa4_ener when absent.""" + raw = _make_raw_model_config() + self.assertNotIn("type", raw["descriptor"]) + self.assertNotIn("type", raw["fitting_net"]) + model = get_model(raw) + self.assertIsInstance(model, EnergyModel) + + def test_pair_exclude_types_from_descriptor(self) -> None: + """descriptor.exclude_types propagates when pair_exclude_types absent.""" + raw = _make_raw_model_config() + raw["descriptor"]["exclude_types"] = [[0, 1]] + model = get_model(raw) + self.assertEqual(model.atomic_model.pair_exclude_types, [[0, 1]]) + + def test_pair_exclude_types_consistent(self) -> None: + """Matching pair_exclude_types and descriptor.exclude_types are accepted.""" + raw = _make_raw_model_config() + raw["descriptor"]["exclude_types"] = [[0, 1]] + raw["pair_exclude_types"] = [[0, 1]] + model = get_model(raw) + self.assertEqual(model.atomic_model.pair_exclude_types, [[0, 1]]) + + def test_pair_exclude_types_mismatch_raises(self) -> None: + raw = _make_raw_model_config() + raw["descriptor"]["exclude_types"] = [[0, 1]] + raw["pair_exclude_types"] = [[0, 0]] + with self.assertRaisesRegex(ValueError, "must match"): + get_model(raw) + + def test_unsupported_keys_raise(self) -> None: + """pt-only SeZM model-level features fail fast with NotImplementedError.""" + cases = { + "spin": {"use_spin": [True, False], "virtual_scale": [0.3]}, + "bridging_method": "ZBL", + "lora": {"rank": 4}, + "use_compile": True, + } + for key, value in cases.items(): + raw = _make_raw_model_config() + raw[key] = value + with self.assertRaises(NotImplementedError, msg=f"key={key}"): + get_model(raw) + + def test_default_unsupported_values_pass(self) -> None: + """Normalized defaults (bridging None, lora None, use_compile False) build.""" + model_params = _normalize_model(_make_raw_model_config()) + self.assertEqual(model_params["bridging_method"], "None") + self.assertIsNone(model_params["lora"]) + self.assertFalse(model_params["use_compile"]) + model = get_model(copy.deepcopy(model_params)) + self.assertIsInstance(model, EnergyModel) + + +if __name__ == "__main__": + unittest.main() From b3651b9d21947f7c75afedec28986018a1331f6f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 19:10:12 +0800 Subject: [PATCH 04/12] fix(pt_expt): guard preset_out_bias, document enable_tf32 in dpa4 assembly --- deepmd/pt_expt/model/get_model.py | 15 +++++++++++++-- source/tests/pt_expt/model/test_get_model_dpa4.py | 2 ++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/model/get_model.py b/deepmd/pt_expt/model/get_model.py index d2e70c7d35..8a4b453a03 100644 --- a/deepmd/pt_expt/model/get_model.py +++ b/deepmd/pt_expt/model/get_model.py @@ -115,8 +115,15 @@ def get_sezm_model(data: dict) -> EnergyModel: Mirrors :func:`deepmd.pt.model.model.get_sezm_model` so that dpa4/sezm training configs are interchangeable between the pt and pt_expt backends. - The pt-only SeZM extensions (bridging, LoRA, compile, spin) are not - supported here and raise ``NotImplementedError``. + The pt-only SeZM extensions (bridging, LoRA, compile, spin, + preset_out_bias) are not supported here and raise + ``NotImplementedError``. + + Notes + ----- + ``enable_tf32`` is accepted but ignored: the pt backend uses it to toggle + TF32 matmul precision, while the pt_expt backend always runs at full + ("highest") matmul precision, which is numerically conservative. """ data = copy.deepcopy(data) if "spin" in data: @@ -135,6 +142,10 @@ def get_sezm_model(data: dict) -> EnergyModel: raise NotImplementedError( "`use_compile` is not supported for DPA4/SeZM in the pt_expt backend." ) + if data.get("preset_out_bias"): + raise NotImplementedError( + "`preset_out_bias` is not supported for DPA4/SeZM in the pt_expt backend." + ) data.pop("type", None) data.setdefault("descriptor", {}) data.setdefault("fitting_net", {}) diff --git a/source/tests/pt_expt/model/test_get_model_dpa4.py b/source/tests/pt_expt/model/test_get_model_dpa4.py index 0641be383c..458f6c5b56 100644 --- a/source/tests/pt_expt/model/test_get_model_dpa4.py +++ b/source/tests/pt_expt/model/test_get_model_dpa4.py @@ -136,6 +136,7 @@ def test_unsupported_keys_raise(self) -> None: "bridging_method": "ZBL", "lora": {"rank": 4}, "use_compile": True, + "preset_out_bias": {"energy": [None, 1.0]}, } for key, value in cases.items(): raw = _make_raw_model_config() @@ -149,6 +150,7 @@ def test_default_unsupported_values_pass(self) -> None: self.assertEqual(model_params["bridging_method"], "None") self.assertIsNone(model_params["lora"]) self.assertFalse(model_params["use_compile"]) + self.assertIsNone(model_params.get("preset_out_bias")) model = get_model(copy.deepcopy(model_params)) self.assertIsInstance(model, EnergyModel) From c8cd87800c4780118160b282709704dbbd0a0091 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 19:13:51 +0800 Subject: [PATCH 05/12] test(pt_expt): tighten dpa4 assembly guard assertions --- deepmd/pt_expt/model/get_model.py | 2 ++ .../tests/pt_expt/model/test_get_model_dpa4.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/deepmd/pt_expt/model/get_model.py b/deepmd/pt_expt/model/get_model.py index 8a4b453a03..94c3a3ae41 100644 --- a/deepmd/pt_expt/model/get_model.py +++ b/deepmd/pt_expt/model/get_model.py @@ -115,6 +115,8 @@ def get_sezm_model(data: dict) -> EnergyModel: Mirrors :func:`deepmd.pt.model.model.get_sezm_model` so that dpa4/sezm training configs are interchangeable between the pt and pt_expt backends. + In addition to the ``SeZM``/``sezm``/``dpa4`` aliases accepted by pt, + pt_expt also accepts ``DPA4``. The pt-only SeZM extensions (bridging, LoRA, compile, spin, preset_out_bias) are not supported here and raise ``NotImplementedError``. diff --git a/source/tests/pt_expt/model/test_get_model_dpa4.py b/source/tests/pt_expt/model/test_get_model_dpa4.py index 458f6c5b56..8788952eee 100644 --- a/source/tests/pt_expt/model/test_get_model_dpa4.py +++ b/source/tests/pt_expt/model/test_get_model_dpa4.py @@ -132,16 +132,19 @@ def test_pair_exclude_types_mismatch_raises(self) -> None: def test_unsupported_keys_raise(self) -> None: """pt-only SeZM model-level features fail fast with NotImplementedError.""" cases = { - "spin": {"use_spin": [True, False], "virtual_scale": [0.3]}, - "bridging_method": "ZBL", - "lora": {"rank": 4}, - "use_compile": True, - "preset_out_bias": {"energy": [None, 1.0]}, + "spin": ({"use_spin": [True, False], "virtual_scale": [0.3]}, "Spin DPA4"), + "bridging_method": ("ZBL", "`bridging_method` is not supported"), + "lora": ({"rank": 4}, "`lora` is not supported"), + "use_compile": (True, "`use_compile` is not supported"), + "preset_out_bias": ( + {"energy": [None, 1.0]}, + "`preset_out_bias` is not supported", + ), } - for key, value in cases.items(): + for key, (value, msg_regex) in cases.items(): raw = _make_raw_model_config() raw[key] = value - with self.assertRaises(NotImplementedError, msg=f"key={key}"): + with self.assertRaisesRegex(NotImplementedError, msg_regex): get_model(raw) def test_default_unsupported_values_pass(self) -> None: From c0bae02fe5f380bc2ee3383212f230bc693fa705 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 19:46:39 +0800 Subject: [PATCH 06/12] test(pt_expt): DPA4 training end-to-end --- source/tests/pt_expt/test_training.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index cbb8368074..282d3a0843 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -8,6 +8,7 @@ 4. Loss decreases over those steps """ +import copy import datetime import os import shutil @@ -126,6 +127,32 @@ } +# DPA4/SeZM uses a dedicated model type ("dpa4") with fixed descriptor and +# fitting types, so it gets a full model config rather than only a descriptor +# dict that can be swapped into ``_make_config``. +_MODEL_DPA4 = { + "type": "dpa4", + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa4", + "sel": 20, + "rcut": 4.0, + "channels": 16, + "n_radial": 8, + "lmax": 2, + "mmax": 1, + "n_blocks": 2, + "seed": 1, + }, + "fitting_net": { + "type": "dpa4_ener", + "neuron": [0], + "seed": 1, + }, + "data_stat_nbatch": 1, +} + + def _assert_compile_predictions_match( testcase: unittest.TestCase, out_c: dict, @@ -284,6 +311,15 @@ def test_training_loop(self) -> None: config = normalize(config) self._run_training(config) + def test_training_loop_dpa4(self) -> None: + """Run a few DPA4/SeZM training steps (model type "dpa4" dispatch).""" + config = _make_config(self.data_dir, numb_steps=5) + config["model"] = copy.deepcopy(_MODEL_DPA4) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + self.assertEqual(config["model"]["type"], "dpa4") + self._run_training(config) + def test_training_loop_compiled(self) -> None: """Run a few training steps with torch.compile enabled.""" config = _make_config(self.data_dir, numb_steps=5) From 7aca8801d668c9c03b16473f8b2ebd798c2c6e44 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 12 Jun 2026 23:59:32 +0800 Subject: [PATCH 07/12] fix(pt_expt): promote DPA4 trainable weights to Parameters; grad parity vs pt --- deepmd/dpmodel/array_api.py | 30 +++ deepmd/dpmodel/descriptor/dpa4.py | 12 +- .../dpmodel/descriptor/dpa4_nn/activation.py | 9 +- .../dpmodel/descriptor/dpa4_nn/edge_cache.py | 12 +- .../dpmodel/descriptor/dpa4_nn/embedding.py | 15 +- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 9 +- deepmd/dpmodel/descriptor/dpa4_nn/indexing.py | 12 +- deepmd/dpmodel/descriptor/dpa4_nn/norm.py | 36 ++- .../dpmodel/descriptor/dpa4_nn/projection.py | 12 +- deepmd/dpmodel/descriptor/dpa4_nn/radial.py | 8 +- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 46 ++-- deepmd/dpmodel/descriptor/dpa4_nn/so3.py | 28 +- deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py | 30 ++- deepmd/pt_expt/common.py | 15 +- deepmd/pt_expt/descriptor/dpa4.py | 84 ++++++ .../pt/model/test_dpa4_ptexpt_grad_parity.py | 242 ++++++++++++++++++ source/tests/pt_expt/descriptor/test_dpa4.py | 40 +++ 17 files changed, 562 insertions(+), 78 deletions(-) create mode 100644 source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 0fed9813dd..c5547af79a 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -15,6 +15,36 @@ Array = np.ndarray | Any # Any to support JAX, PyTorch, etc. arrays +def xp_asarray_nodetach( + xp: Any, + obj: Any, + *, + dtype: Any = None, + device: Any = None, +) -> Array: + """``xp.asarray`` that preserves autograd for backend tensors. + + ``torch.asarray`` detaches its input from the autograd graph, so calling + ``xp.asarray`` on a weight attribute that is already a backend tensor + (e.g. a ``torch.nn.Parameter`` registered by the pt_expt backend) + silently breaks gradient flow to that weight. This helper converts + genuine non-backend data (numpy arrays, python scalars/lists) via + ``xp.asarray``; backend tensors are returned as-is, with an optional + differentiable dtype cast via ``xp.astype``. + + The ``device`` argument only applies to the conversion path: backend + tensors are assumed to already live on the working device (they are + created together with the inputs). + """ + if isinstance(obj, np.ndarray) or not array_api_compat.is_array_api_obj(obj): + if dtype is None: + return xp.asarray(obj, device=device) + return xp.asarray(obj, dtype=dtype, device=device) + if dtype is not None and obj.dtype != dtype: + obj = xp.astype(obj, dtype) + return obj + + # array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816 # but it hasn't been released yet # below is a pure Python implementation of take_along_axis diff --git a/deepmd/dpmodel/descriptor/dpa4.py b/deepmd/dpmodel/descriptor/dpa4.py index 628b5dfb73..effe51afdb 100644 --- a/deepmd/dpmodel/descriptor/dpa4.py +++ b/deepmd/dpmodel/descriptor/dpa4.py @@ -96,6 +96,10 @@ WignerDCalculator, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) + if TYPE_CHECKING: from deepmd.dpmodel.array_api import ( Array, @@ -856,10 +860,10 @@ def call( shift_hat = self.film_shift_norm(shift_logits) device = array_api_compat.device(scale_hat) scale_strength = xp.exp( - xp.asarray(self.film_scale_strength_log, device=device) + xp_asarray_nodetach(xp, self.film_scale_strength_log, device=device) ) shift_strength = xp.exp( - xp.asarray(self.film_shift_strength_log, device=device) + xp_asarray_nodetach(xp, self.film_shift_strength_log, device=device) ) scale = 1.0 + scale_strength * xp.tanh(scale_hat) shift = shift_strength * xp.tanh(shift_hat) @@ -934,7 +938,9 @@ def _build_gie_zonal_coupling(self, edge_cache: EdgeCache) -> Any: mp_cols = self.gie.zonal_m0_col_index_for_row[:mp_row_count] Dt_full = edge_cache.Dt_full dim_full = Dt_full.shape[-1] - flat_index = xp.asarray(mp_rows * dim_full + mp_cols, device=device) + flat_index = xp_asarray_nodetach( + xp, mp_rows * dim_full + mp_cols, device=device + ) mp_coupling = xp.take( xp.reshape(Dt_full, (n_edge, dim_full * dim_full)), flat_index, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py index 23fbeada4a..6e5ba7dbc2 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py @@ -33,6 +33,7 @@ NativeOP, ) from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, xp_sigmoid, ) from deepmd.dpmodel.common import ( @@ -211,8 +212,8 @@ def call(self, x: Any, gate: Any = None) -> Any: if self.lmax == 0: return x0 - gate_weight = xp.asarray( - self.gate_linear.weight[...], device=array_api_compat.device(x) + gate_weight = xp_asarray_nodetach( + xp, self.gate_linear.weight[...], device=array_api_compat.device(x) ) input_dtype = gate_scalar_source.dtype if input_dtype != gate_weight.dtype: @@ -224,7 +225,9 @@ def call(self, x: Any, gate: Any = None) -> Any: gating_scalars, (x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels), ) - expand_index = xp.asarray(self.expand_index, device=array_api_compat.device(x)) + expand_index = xp_asarray_nodetach( + xp, self.expand_index, device=array_api_compat.device(x) + ) gates = xp.take(gating_scalars, expand_index, axis=2) # (N, F, D-1, C) if self.layout == "ndfc": gates = xp.permute_dims(gates, (0, 2, 1, 3)) # (N, D-1, F, C) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py b/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py index e2afba1670..03c3a57031 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py @@ -39,6 +39,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from .utils import ( @@ -323,7 +327,9 @@ def build_edge_cache( # === Step 4. Rewrite invalid slots to the safe +z dummy vector === # Gradient safety: see the function docstring. maskf = xp.astype(mask_flat, vec.dtype)[:, None] # (E, 1) - z_unit = xp.asarray(np.array([[0.0, 0.0, 1.0]]), dtype=vec.dtype, device=device) + z_unit = xp_asarray_nodetach( + xp, np.array([[0.0, 0.0, 1.0]]), dtype=vec.dtype, device=device + ) edge_vec = vec * maskf + (1.0 - maskf) * z_unit edge_len = safe_norm(edge_vec, eps) # (E, 1) @@ -336,7 +342,9 @@ def build_edge_cache( if random_gamma: if gamma is None: gamma = np.random.default_rng().uniform(0.0, 2.0 * math.pi, n_edge) - gamma = xp.astype(xp.asarray(gamma, device=device), edge_quat.dtype) + gamma = xp.astype( + xp_asarray_nodetach(xp, gamma, device=device), edge_quat.dtype + ) edge_quat = quaternion_multiply(quaternion_z_rotation(gamma), edge_quat) D_full, Dt_full = wigner_calc(edge_quat) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py index 152b8b5a55..40f59b3b00 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py @@ -39,6 +39,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -157,8 +161,8 @@ def call(self, atype: Any) -> Any: Type embeddings with shape (..., embed_dim). """ xp = array_api_compat.array_namespace(atype) - weight = xp.asarray( - self.adam_type_embedding[...], device=array_api_compat.device(atype) + weight = xp_asarray_nodetach( + xp, self.adam_type_embedding[...], device=array_api_compat.device(atype) ) # pt embedding.py:143 torch.embedding -> flat int64 take + reshape. index = xp.astype(xp.reshape(atype, (-1,)), xp.int64) @@ -297,7 +301,8 @@ def call( if zonal_coupling is None: Dt_full = edge_cache.Dt_full # (E, D, D) dim_full = Dt_full.shape[-1] - flat_index = xp.asarray( + flat_index = xp_asarray_nodetach( + xp, self.non_scalar_row_index * dim_full + self.zonal_m0_col_index_for_row, device=device, ) @@ -310,7 +315,9 @@ def call( # === Step 3. Broadcast radial features per row === # Each non-scalar packed row reuses the radial feature of its degree l # (pt embedding.py:245-250, index_select on axis 1). - radial_slot_index = xp.asarray(self.radial_slot_index_for_row, device=device) + radial_slot_index = xp_asarray_nodetach( + xp, self.radial_slot_index_for_row, device=device + ) radial_value_for_row = xp.take( radial_feat, radial_slot_index, axis=1 ) # (E, D-1, C) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 647d112722..70db705bb9 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -59,6 +59,7 @@ NativeOP, ) from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, xp_sigmoid, ) from deepmd.dpmodel.common import ( @@ -423,8 +424,8 @@ def _to_grid(self, coeff: Any) -> Any: # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul xp = array_api_compat.array_namespace(coeff) n_batch, coeff_dim, n_focus, _ = coeff.shape - to_grid_mat = xp.asarray( - self.projector.to_grid_mat[...], device=array_api_compat.device(coeff) + to_grid_mat = xp_asarray_nodetach( + xp, self.projector.to_grid_mat[...], device=array_api_compat.device(coeff) ) if to_grid_mat.dtype != coeff.dtype: to_grid_mat = xp.astype(to_grid_mat, coeff.dtype) @@ -439,8 +440,8 @@ def _from_grid(self, grid: Any) -> Any: xp = array_api_compat.array_namespace(grid) n_batch, n_grid, n_focus, _ = grid.shape coeff_dim = self.projector.coeff_dim - from_grid_mat = xp.asarray( - self.projector.from_grid_mat[...], device=array_api_compat.device(grid) + from_grid_mat = xp_asarray_nodetach( + xp, self.projector.from_grid_mat[...], device=array_api_compat.device(grid) ) if from_grid_mat.dtype != grid.dtype: from_grid_mat = xp.astype(from_grid_mat, grid.dtype) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py b/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py index 179dcefd9f..080bcb58d9 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py @@ -22,6 +22,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np @@ -169,7 +173,9 @@ def project_D_to_m( xp = array_api_compat.array_namespace(D_full) D_block = D_full[:, :ebed_dim_full, :ebed_dim_full] - index = xp.asarray(coeff_index_m, device=array_api_compat.device(D_full)) + index = xp_asarray_nodetach( + xp, coeff_index_m, device=array_api_compat.device(D_full) + ) proj = xp.take(D_block, index, axis=1) if cache is not None: cache[cache_key] = proj @@ -223,7 +229,9 @@ def project_Dt_from_m( xp = array_api_compat.array_namespace(Dt_full) Dt_block = Dt_full[:, :ebed_dim_full, :ebed_dim_full] - index = xp.asarray(coeff_index_m, device=array_api_compat.device(Dt_full)) + index = xp_asarray_nodetach( + xp, coeff_index_m, device=array_api_compat.device(Dt_full) + ) proj = xp.take(Dt_block, index, axis=2) if cache is not None: cache[cache_key] = proj diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py index 420f6d469b..0e5cafde64 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py @@ -21,6 +21,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -93,7 +97,9 @@ def call(self, x: Any) -> Any: Normalized array with shape `(..., C)`, same dtype as input. """ xp = array_api_compat.array_namespace(x) - scale = xp.asarray(self.adam_scale[...], device=array_api_compat.device(x)) + scale = xp_asarray_nodetach( + xp, self.adam_scale[...], device=array_api_compat.device(x) + ) in_dtype = x.dtype if in_dtype != scale.dtype: x = xp.astype(x, scale.dtype) @@ -233,10 +239,10 @@ def call(self, x: Any) -> Any: """ xp = array_api_compat.array_namespace(x) device = array_api_compat.device(x) - scale = xp.asarray(self.adam_scale[...], device=device) - bias = xp.asarray(self.bias[...], device=device) - balance_weight = xp.asarray( - self.balance_weight, device=array_api_compat.device(x) + scale = xp_asarray_nodetach(xp, self.adam_scale[...], device=device) + bias = xp_asarray_nodetach(xp, self.bias[...], device=device) + balance_weight = xp_asarray_nodetach( + xp, self.balance_weight, device=array_api_compat.device(x) ) in_dtype = x.dtype if in_dtype != scale.dtype: @@ -261,7 +267,9 @@ def call(self, x: Any) -> Any: xt = xt * inv_rms # === Step 3. Apply per-degree affine parameters === - expand_index = xp.asarray(self.expand_index, device=array_api_compat.device(x)) + expand_index = xp_asarray_nodetach( + xp, self.expand_index, device=array_api_compat.device(x) + ) expanded_scale = xp.take(scale, expand_index, axis=0) expanded_scale = expanded_scale[None, ...] # (1, D, F, C) x0 = x0 * expanded_scale[:, :1, :, :] @@ -435,10 +443,10 @@ def call(self, x: Any) -> Any: """ xp = array_api_compat.array_namespace(x) device = array_api_compat.device(x) - scale = xp.asarray(self.adam_scale[...], device=device) - bias0_w = xp.asarray(self.bias0[...], device=device) - balance_weight = xp.asarray( - self.balance_weight, device=array_api_compat.device(x) + scale = xp_asarray_nodetach(xp, self.adam_scale[...], device=device) + bias0_w = xp_asarray_nodetach(xp, self.bias0[...], device=device) + balance_weight = xp_asarray_nodetach( + xp, self.balance_weight, device=array_api_compat.device(x) ) in_dtype = x.dtype if in_dtype != scale.dtype: @@ -464,8 +472,8 @@ def call(self, x: Any) -> Any: xt = xt * inv_rms # === Step 3. Apply per-degree affine parameters === - degree_index_m = xp.asarray( - self.degree_index_m, device=array_api_compat.device(x) + degree_index_m = xp_asarray_nodetach( + xp, self.degree_index_m, device=array_api_compat.device(x) ) expanded_scale = xp.take(scale, degree_index_m, axis=1) expanded_scale = expanded_scale[None, ...] # (1, F, D_m_trunc, C) @@ -598,7 +606,9 @@ def call(self, x: Any) -> Any: Normalized array with the same shape as input and same dtype. """ xp = array_api_compat.array_namespace(x) - scale = xp.asarray(self.adam_scale[...], device=array_api_compat.device(x)) + scale = xp_asarray_nodetach( + xp, self.adam_scale[...], device=array_api_compat.device(x) + ) in_dtype = x.dtype if in_dtype != scale.dtype: x = xp.astype(x, scale.dtype) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index bc1ce5add7..b21ccfaccb 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -42,6 +42,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -123,8 +127,8 @@ def call(self, *args: Any, **kwargs: Any) -> Any: def to_grid(self, embedding: Any) -> Any: """Project flattened coefficients ``(N, J, C)`` to grid fields ``(N, G, C)``.""" xp = array_api_compat.array_namespace(embedding) - to_grid_mat = xp.asarray( - self.to_grid_mat[...], device=array_api_compat.device(embedding) + to_grid_mat = xp_asarray_nodetach( + xp, self.to_grid_mat[...], device=array_api_compat.device(embedding) ) if to_grid_mat.dtype != embedding.dtype: to_grid_mat = xp.astype(to_grid_mat, embedding.dtype) @@ -134,8 +138,8 @@ def to_grid(self, embedding: Any) -> Any: def from_grid(self, grid: Any) -> Any: """Project grid fields ``(N, G, C)`` back to flattened coefficients ``(N, J, C)``.""" xp = array_api_compat.array_namespace(grid) - from_grid_mat = xp.asarray( - self.from_grid_mat[...], device=array_api_compat.device(grid) + from_grid_mat = xp_asarray_nodetach( + xp, self.from_grid_mat[...], device=array_api_compat.device(grid) ) if from_grid_mat.dtype != grid.dtype: from_grid_mat = xp.astype(from_grid_mat, grid.dtype) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/radial.py b/deepmd/dpmodel/descriptor/dpa4_nn/radial.py index b93836ea8e..4428aab5fb 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/radial.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/radial.py @@ -22,6 +22,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -411,7 +415,9 @@ def call(self, r: Any) -> Any: (N, n_rbf). The output is smoothly truncated to zero at r = rcut. """ xp = array_api_compat.array_namespace(r) - freqs = xp.asarray(self.adam_freqs, device=array_api_compat.device(r)) + freqs = xp_asarray_nodetach( + xp, self.adam_freqs, device=array_api_compat.device(r) + ) # === Step 1. Radial basis === # Shape: (N, 1) * (1, n_radial) -> (N, n_radial) if self.basis_type == "bessel": diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 12b0ff5371..1986b5e9c9 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -45,6 +45,7 @@ NativeOP, ) from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, xp_sigmoid, ) from deepmd.dpmodel.common import ( @@ -354,7 +355,7 @@ def call(self, x: Any) -> Any: num_m0 = self.lmax + 1 device = array_api_compat.device(x) weight_m0 = xp.reshape( - xp.asarray(self.weight_m0[...], device=device), + xp_asarray_nodetach(xp, self.weight_m0[...], device=device), (num_m0 * self.in_channels, self.n_focus, num_m0 * self.out_channels), ) weight_m0 = xp.permute_dims(weight_m0, (1, 0, 2)) # (F, in, out) @@ -366,7 +367,8 @@ def call(self, x: Any) -> Any: ib = ni1 - ni0 # in_block size ob = no1 - no0 # out_block size w = xp.reshape( - xp.asarray(w[...], device=device), (ib, self.n_focus, 2 * ob) + xp_asarray_nodetach(xp, w[...], device=device), + (ib, self.n_focus, 2 * ob), ) w = xp.permute_dims(w, (1, 0, 2)) # (F, in_blk, 2*out_blk) w_u = w[:, :, :ob] # (F, in_blk, out_blk) @@ -392,7 +394,7 @@ def call(self, x: Any) -> Any: # === Step 3. Bias on l=0 scalar index === if self.mlp_bias: bias0 = xp.reshape( - xp.asarray(self.bias0[...], device=device), + xp_asarray_nodetach(xp, self.bias0[...], device=device), (self.n_focus, self.out_channels), ) out0 = out[:, :, :1, :] + bias0[None, :, None, :] @@ -621,16 +623,17 @@ def _project_radial(self, xp: Any, radial_feat: Any) -> Any: radial_feat[:, : self.lmax + 1, :], (radial_feat.shape[0], self.input_dim), ) - weight = xp.asarray( - self.weight[...], device=array_api_compat.device(radial_feat) + weight = xp_asarray_nodetach( + xp, self.weight[...], device=array_api_compat.device(radial_feat) ) return xp.matmul(radial_m0, weight) def _scatter_dense(self, xp: Any, compact: Any, device: Any) -> Any: """Scatter the compact per-block kernel into the dense (D_m*D_m, ...) layout.""" - gather_index = xp.asarray(self._dense_gather_index, device=device) + gather_index = xp_asarray_nodetach(xp, self._dense_gather_index, device=device) scatter_mask = xp.astype( - xp.asarray(self._dense_scatter_mask, device=device), compact.dtype + xp_asarray_nodetach(xp, self._dense_scatter_mask, device=device), + compact.dtype, ) dense = xp.take(compact, gather_index, axis=1) if compact.ndim == 2: @@ -675,7 +678,7 @@ def call(self, x_local: Any, radial_feat: Any) -> Any: kernel = xp.permute_dims(kernel, (0, 1, 3, 2)) mixed = xp.matmul(kernel, x_local[:, None, :, :]) channel_basis = xp.reshape( - xp.asarray(self.channel_basis[...], device=device), + xp_asarray_nodetach(xp, self.channel_basis[...], device=device), (1, 1, self.rank, self.channels), ) return xp.sum(mixed * channel_basis, axis=2) @@ -1209,7 +1212,7 @@ def call( x_local = xp.matmul(D_m_prime, x_src) # (E, D_m, C_wide) # === Step 3. Select radial/type features for reduced layout === - degree_index_m = xp.asarray(self.degree_index_m, device=device) + degree_index_m = xp_asarray_nodetach(xp, self.degree_index_m, device=device) rad_feat = xp.take(radial_feat, degree_index_m, axis=1) # (E, D_m, C) if self.radial_hidden_proj is not None: rad_feat = self.radial_hidden_proj(rad_feat) @@ -1241,7 +1244,7 @@ def apply_bias_correction( if layer_idx != 0 or so2_linear.bias0 is None: return x_local bias0 = xp.reshape( - xp.asarray(so2_linear.bias0[...], device=device), + xp_asarray_nodetach(xp, so2_linear.bias0[...], device=device), (1, self.n_focus, so2_linear.out_channels), ) if so2_linear.out_channels == self.so2_focus_dim: @@ -1278,7 +1281,9 @@ def apply_bias_correction( # === Step 6. Cross-focus softmax competition === if self.focus_compete and self.n_focus > 1: - compete_w = xp.asarray(self.adamw_focus_compete_w[...], device=device) + compete_w = xp_asarray_nodetach( + xp, self.adamw_focus_compete_w[...], device=device + ) gate_in = xp.astype(focus_gate_src, compete_w.dtype) gate_normed = self.focus_compete_norm(gate_in) # (E, F, Cf) # einsum "efi,if->ef" @@ -1289,7 +1294,9 @@ def apply_bias_correction( if self.mlp_bias: focus_logits = ( focus_logits - + xp.asarray(self.focus_compete_bias[...], device=device)[None, :] + + xp_asarray_nodetach( + xp, self.focus_compete_bias[...], device=device + )[None, :] ) focus_logits = focus_logits / self.focus_softmax_tau logits_max = xp.max(focus_logits, axis=1, keepdims=True) @@ -1321,7 +1328,8 @@ def apply_bias_correction( # inverse-rotation degree rescale after the global lift restores the # full-basis amplitude expected by the block output contract. rescale = xp.astype( - xp.asarray(self.rotate_inv_rescale_full, device=device), x_message.dtype + xp_asarray_nodetach(xp, self.rotate_inv_rescale_full, device=device), + x_message.dtype, ) x_message = x_message * xp.reshape(rescale, (1, -1, 1)) @@ -1352,7 +1360,7 @@ def apply_bias_correction( out = out * inv_sqrt_deg # (N, D, C_wide) else: # === Step 8.1. Build attention logits from scalar channels === - qk_w = xp.asarray(self.attn_q_proj.weight[...], device=device) + qk_w = xp_asarray_nodetach(xp, self.attn_q_proj.weight[...], device=device) x_l0_node = xp.reshape( x[:, 0, :], (n_node, self.attn_n_focus, self.attn_focus_dim) ) # (N, Fa, Ca) @@ -1375,7 +1383,8 @@ def apply_bias_correction( radial_l0 = xp.astype(radial_l0, qk_w.dtype) # einsum "efi,ifo->efo" as a broadcast batched matmul. logit_w = xp.permute_dims( - xp.asarray(self.adamw_attn_logit_w[...], device=device), (1, 0, 2) + xp_asarray_nodetach(xp, self.adamw_attn_logit_w[...], device=device), + (1, 0, 2), ) # (Fa, Ca, H) radial_bias = xp.matmul(radial_l0[:, :, None, :], logit_w[None, ...])[ ..., 0, : @@ -1392,7 +1401,9 @@ def apply_bias_correction( logits=attn_logits, edge_env=xp.astype(edge_cache.edge_env, attn_logits.dtype), n_nodes=n_node, - z_bias_raw=xp.asarray(self.adamw_attn_z_bias_raw, device=device), + z_bias_raw=xp_asarray_nodetach( + xp, self.adamw_attn_z_bias_raw, device=device + ), eps=self.eps, src_weight=( None @@ -1439,7 +1450,8 @@ def apply_bias_correction( # === Step 8.4. Output-side head gate === gate_w = xp.permute_dims( - xp.asarray(self.adamw_attn_gate_w[...], device=device), (1, 0, 2) + xp_asarray_nodetach(xp, self.adamw_attn_gate_w[...], device=device), + (1, 0, 2), ) # (Fa, Ca, H) gate_in = self.attn_output_gate_norm(x_l0_node) attn_output_gate = xp_sigmoid( diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py index 60e8c2bceb..699d8cbc93 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py @@ -32,6 +32,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -139,14 +143,18 @@ def call(self, x: Any) -> Any: Projected array with shape (B, F, Cout). """ xp = array_api_compat.array_namespace(x) - weight = xp.asarray(self.weight[...], device=array_api_compat.device(x)) + weight = xp_asarray_nodetach( + xp, self.weight[...], device=array_api_compat.device(x) + ) weight = xp.reshape(weight, (self.in_channels, self.n_focus, self.out_channels)) # einsum "bfi,ifo->bfo" as a broadcast batched matmul: # (B, F, 1, Cin) @ (1, F, Cin, Cout) -> (B, F, 1, Cout) weight = xp.permute_dims(weight, (1, 0, 2)) # (F, Cin, Cout) out = xp.matmul(x[:, :, None, :], weight[None, ...])[..., 0, :] if self.use_bias: - bias = xp.asarray(self.bias[...], device=array_api_compat.device(x)) + bias = xp_asarray_nodetach( + xp, self.bias[...], device=array_api_compat.device(x) + ) bias = xp.reshape(bias, (self.n_focus, self.out_channels)) out = out + bias[None, ...] return out @@ -285,9 +293,9 @@ def call(self, x: Any) -> Any: xp = array_api_compat.array_namespace(x) # einsum "...i,io->...o" is a plain matmul on the last axis device = array_api_compat.device(x) - out = xp.matmul(x, xp.asarray(self.weight[...], device=device)) + out = xp.matmul(x, xp_asarray_nodetach(xp, self.weight[...], device=device)) if self.use_bias: - out = out + xp.asarray(self.bias[...], device=device) + out = out + xp_asarray_nodetach(xp, self.bias[...], device=device) return out def serialize(self) -> dict[str, Any]: @@ -457,11 +465,15 @@ def call(self, x: Any) -> Any: # === Step 1. Expand per-l weights to packed coefficient layout === # (L, Cin, F*Cout) -> (L, Cin, F, Cout) weight = xp.reshape( - xp.asarray(self.weight[...], device=array_api_compat.device(x)), + xp_asarray_nodetach( + xp, self.weight[...], device=array_api_compat.device(x) + ), (self.lmax + 1, self.in_channels, self.n_focus, self.out_channels), ) # (L, Cin, F, Cout) # (L, Cin, F, Cout) -> (D, Cin, F, Cout) - expand_index = xp.asarray(self.expand_index, device=array_api_compat.device(x)) + expand_index = xp_asarray_nodetach( + xp, self.expand_index, device=array_api_compat.device(x) + ) weight_expanded = xp.take(weight, expand_index, axis=0) # === Step 2. Per-focus, per-degree channel mixing === @@ -474,7 +486,9 @@ def call(self, x: Any) -> Any: # === Step 3. Add l=0 bias === if self.mlp_bias: - bias = xp.asarray(self.bias[...], device=array_api_compat.device(x)) + bias = xp_asarray_nodetach( + xp, self.bias[...], device=array_api_compat.device(x) + ) bias = xp.reshape(bias, (self.n_focus, self.out_channels)) out0 = out[:, :1, :, :] + bias[None, None, ...] out = xp.concat([out0, out[:, 1:, :, :]], axis=1) if self.lmax > 0 else out0 diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py index 0612d233be..c630a952f5 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py @@ -40,6 +40,10 @@ ) import array_api_compat + +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) import numpy as np from deepmd.dpmodel import ( @@ -691,7 +695,7 @@ def call(self, edge_quaternion: Any) -> tuple[Any, Any]: segments.append(xp.reshape(packed, (n_edge, packed_size * packed_size))) segments.append(xp.zeros((n_edge, 1), dtype=dtype, device=device)) values = xp.concat(segments, axis=1) - idx = xp.asarray(self.full_gather_idx, device=device) + idx = xp_asarray_nodetach(xp, self.full_gather_idx, device=device) D_full = xp.reshape( xp.take(values, idx, axis=1), (n_edge, self.dim_full, self.dim_full), @@ -756,7 +760,7 @@ def _compute_l1_block(self, q: Any, xp: Any, dtype: Any, device: Any) -> Any: # row/column permutation [1, 2, 0], applied structurally (no gather) rot = xp.stack([rot[..., 1, :], rot[..., 2, :], rot[..., 0, :]], axis=-2) rot = xp.stack([rot[..., 1], rot[..., 2], rot[..., 0]], axis=-1) - sign = xp.asarray(self.l1_sign_outer, dtype=dtype, device=device) + sign = xp_asarray_nodetach(xp, self.l1_sign_outer, dtype=dtype, device=device) return rot * sign def _compute_packed_blocks(self, q: Any, xp: Any, dtype: Any, device: Any) -> Any: @@ -769,10 +773,10 @@ def _compute_packed_blocks(self, q: Any, xp: Any, dtype: Any, device: Any) -> An D_re, D_im = self._wigner_d_matrix_realpair( ra_re, ra_im, rb_re, rb_im, xp, dtype, device ) - u_re = xp.asarray(self.poly_u_re, dtype=dtype, device=device) - u_im = xp.asarray(self.poly_u_im, dtype=dtype, device=device) - u_re_t = xp.asarray(self.poly_u_re_t, dtype=dtype, device=device) - u_im_t = xp.asarray(self.poly_u_im_t, dtype=dtype, device=device) + u_re = xp_asarray_nodetach(xp, self.poly_u_re, dtype=dtype, device=device) + u_im = xp_asarray_nodetach(xp, self.poly_u_im, dtype=dtype, device=device) + u_re_t = xp_asarray_nodetach(xp, self.poly_u_re_t, dtype=dtype, device=device) + u_im_t = xp_asarray_nodetach(xp, self.poly_u_im_t, dtype=dtype, device=device) temp_re = xp.matmul(D_re, u_re_t) + xp.matmul(D_im, u_im_t) temp_im = xp.matmul(D_im, u_re_t) - xp.matmul(D_re, u_im_t) return xp.matmul(u_re, temp_re) - xp.matmul(u_im, temp_im) @@ -805,7 +809,7 @@ def _wigner_d_matrix_realpair( rb_im = xp.astype(rb_im, f64) def cv(arr: np.ndarray) -> Any: # constant table -> xp on input device - return xp.asarray(arr, device=device) + return xp_asarray_nodetach(xp, arr, device=device) eps = float(np.finfo(np.float64).eps) eps_sq = eps * eps @@ -929,16 +933,18 @@ def _compute_case_magnitude( horner_sum = _vectorized_horner( xp, ratio, - xp.asarray(case.horner, device=device), - xp.asarray(case.horner_step_mask, device=device), + xp_asarray_nodetach(xp, case.horner, device=device), + xp_asarray_nodetach(xp, case.horner_step_mask, device=device), ) ra_powers = xp.exp( - log_ra[:, None] * xp.asarray(case.ra_exp, device=device)[None, :] + log_ra[:, None] + * xp_asarray_nodetach(xp, case.ra_exp, device=device)[None, :] ) rb_powers = xp.exp( - log_rb[:, None] * xp.asarray(case.rb_exp, device=device)[None, :] + log_rb[:, None] + * xp_asarray_nodetach(xp, case.rb_exp, device=device)[None, :] ) - signed_coeff = xp.asarray(case.signed_coeff, device=device) + signed_coeff = xp_asarray_nodetach(xp, case.signed_coeff, device=device) magnitude = signed_coeff[None, :] * ra_powers * rb_powers return magnitude * horner_sum diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index ab0830d498..5676652ba3 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -178,15 +178,18 @@ def _try_convert_list(name: str, value: list) -> torch.nn.Module | None: """ if not value: return None - # List of torch.nn.Module → ModuleList - if all(isinstance(v, torch.nn.Module) for v in value): - return torch.nn.ModuleList(value) - # List of NativeOP (not yet Module) → convert each + ModuleList - if all( - isinstance(v, NativeOP) and not isinstance(v, torch.nn.Module) for v in value + # List of (torch.nn.Module | NativeOP | None) with at least one module-like + # entry → ModuleList (None entries are preserved: optional sub-modules such + # as per-degree activations/norms are None when disabled). Plain NativeOP + # entries are converted first. + if any(isinstance(v, (torch.nn.Module, NativeOP)) for v in value) and all( + v is None or isinstance(v, (torch.nn.Module, NativeOP)) for v in value ): converted = [] for v in value: + if v is None or isinstance(v, torch.nn.Module): + converted.append(v) + continue c = try_convert_module(v) if c is None: raise TypeError( diff --git a/deepmd/pt_expt/descriptor/dpa4.py b/deepmd/pt_expt/descriptor/dpa4.py index 9aa2473560..dd93ac5e3c 100644 --- a/deepmd/pt_expt/descriptor/dpa4.py +++ b/deepmd/pt_expt/descriptor/dpa4.py @@ -3,6 +3,8 @@ Any, ) +import torch + from deepmd.dpmodel.descriptor.dpa4 import DescrptDPA4 as DescrptDPA4DP from deepmd.dpmodel.descriptor.dpa4_nn.activation import SwiGLU as SwiGLUDP from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( @@ -44,6 +46,77 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: register_dpmodel_mapping(SwiGLUDP, lambda v: SwiGLU()) +# --------------------------------------------------------------------------- +# Trainable-weight promotion +# +# ``dpmodel_setattr`` registers every numpy attribute as a torch *buffer*, so +# the auto-wrapped dpa4_nn sub-modules would otherwise expose their trainable +# weights as non-trainable buffers (no autograd, invisible to the optimizer). +# The table below lists, per dpmodel class name, the attributes that are +# ``torch.nn.Parameter`` in the reference pt SeZM implementation +# (deepmd/pt/model/descriptor/sezm_nn). ``_promote_trainable_tree`` walks the +# fully-built module tree and re-registers those buffers as Parameters. +# +# Constant float buffers (e.g. ``balance_weight``, ``rotate_inv_rescale_full``, +# ``mean``/``stddev``) are intentionally NOT listed: they are buffers in pt +# too. Lists of weights (e.g. ``SO2Linear.weight_m``) are already converted +# to trainable ``ParameterList`` by ``_try_convert_list``. +# --------------------------------------------------------------------------- +_TRAINABLE_ATTRS: dict[str, tuple[str, ...]] = { + # dpa4_nn.norm + "RMSNorm": ("adam_scale",), + "EquivariantRMSNorm": ("adam_scale", "bias"), + "ReducedEquivariantRMSNorm": ("adam_scale", "bias0"), + "ScalarRMSNorm": ("adam_scale",), + # dpa4_nn.radial + "RadialBasis": ("adam_freqs",), + # dpa4_nn.so3 + "SO3Linear": ("weight", "bias"), + "FocusLinear": ("weight", "bias"), + "ChannelLinear": ("weight", "bias"), + # dpa4_nn.so2 + "SO2Linear": ("weight_m0", "bias0"), + "DynamicRadialDegreeMixer": ("weight", "channel_basis"), + "SO2Convolution": ( + "adamw_attn_logit_w", + "adamw_attn_z_bias_raw", + "adamw_attn_gate_w", + "adamw_focus_compete_w", + "focus_compete_bias", + ), + # dpa4_nn.embedding + "SeZMTypeEmbedding": ("adam_type_embedding",), + # descriptor-level FiLM strengths + "DescrptDPA4": ("film_scale_strength_log", "film_shift_strength_log"), +} + + +def _promote_trainable(module: torch.nn.Module, names: tuple[str, ...]) -> None: + """Re-register the given float buffers of *module* as Parameters.""" + if not getattr(module, "trainable", True): + return + for name in names: + buf = module._buffers.get(name) + if buf is None or not buf.is_floating_point(): + continue + del module._buffers[name] + setattr(module, name, torch.nn.Parameter(buf, requires_grad=True)) + + +def _promote_trainable_tree(module: torch.nn.Module) -> torch.nn.Module: + """Promote trainable buffers to Parameters across the whole module tree. + + Must run after the tree is fully built (post ``__init__`` / + ``deserialize``): dpmodel deserialize may assign numpy arrays onto nested + attributes, which ``dpmodel_setattr`` would re-register as buffers. + """ + for sub in module.modules(): + names = _TRAINABLE_ATTRS.get(type(sub).__name__) + if names is not None: + _promote_trainable(sub, names) + return module + + @BaseDescriptor.register("SeZM") @BaseDescriptor.register("sezm") @BaseDescriptor.register("DPA4") @@ -52,6 +125,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class DescrptDPA4(DescrptDPA4DP): _update_sel_cls = UpdateSel + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + _promote_trainable_tree(self) + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA4": + # deserialize assigns numpy arrays after __init__, which demotes + # promoted Parameters back to buffers; re-promote at the end. + obj = super().deserialize(data) + return _promote_trainable_tree(obj) + def forward(self, *args: Any, **kwargs: Any) -> Any: return self.call(*args, **kwargs) diff --git a/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py b/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py new file mode 100644 index 0000000000..530c76edac --- /dev/null +++ b/source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parameter-gradient parity: pt SeZM (reference) vs pt_expt DPA4 wrappers. + +The pt_expt wrappers register dpmodel numpy attributes as torch buffers by +default; ``_promote_trainable_tree`` (deepmd/pt_expt/descriptor/dpa4.py) must +re-register every weight that is a trainable ``nn.Parameter`` in pt as a +Parameter so the optimizer sees it and autograd populates its grad. This file +proves that promotion is complete and correct: + +- weights are transferred pt -> pt_expt via serialize()/deserialize(), +- the forward outputs must match (guard assertion), +- a quadratic loss is backpropagated on both sides, +- every gradient is compared 1:1 through the shared serialization contract: + both sides serialize to the same ``@variables`` key names (pt state_dict key + contract), so swapping each Parameter's data with its grad and re-serializing + yields name-aligned gradient trees. A weight that is a Parameter in pt but + was left a buffer in pt_expt shows up as grad-vs-weight mismatch (no silent + drops); parameter counts are asserted equal as well. +""" + +import numpy as np +import pytest +import torch + +from deepmd.pt.utils import env as pt_env + +from .test_dpa4_dpmodel_parity import ( + _build_descriptor_inputs, +) + +PT_DEVICE = pt_env.DEVICE +_ON_CPU = PT_DEVICE.type == "cpu" +# device-conditional gates (see test_dpa4_dpmodel_parity.py header) +PT_RTOL, PT_ATOL = (1e-12, 1e-14) if _ON_CPU else (1e-10, 1e-12) + + +def to_pt(x: np.ndarray) -> torch.Tensor: + return torch.from_numpy(np.ascontiguousarray(x)).to(PT_DEVICE) + + +def _flatten_arrays(data, prefix="") -> dict[str, np.ndarray]: + """Flatten a serialize() tree to {dotted-path: ndarray}.""" + out: dict[str, np.ndarray] = {} + if isinstance(data, dict): + items = data.items() + elif isinstance(data, (list, tuple)): + items = ((str(i), v) for i, v in enumerate(data)) + else: + return out + for k, v in items: + key = f"{prefix}.{k}" if prefix else str(k) + if isinstance(v, np.ndarray): + out[key] = v + elif isinstance(v, (dict, list, tuple)): + out.update(_flatten_arrays(v, key)) + return out + + +def _swap_data_with_grad(module: torch.nn.Module) -> int: + """Replace each Parameter's data with its gradient, in place. + + Every requires-grad Parameter must have received a non-None grad + (asserted); returns the number of swapped parameters. + """ + missing = [ + n for n, p in module.named_parameters() if p.requires_grad and p.grad is None + ] + assert not missing, f"parameters with no grad after backward: {missing}" + n_swapped = 0 + with torch.no_grad(): + for _, p in module.named_parameters(): + if p.requires_grad: + p.data = p.grad.detach().clone() + n_swapped += 1 + return n_swapped + + +def _assert_grad_trees_match(pt_mod, expt_mod, rtol=PT_RTOL, atol=PT_ATOL) -> None: + """Swap data<->grad on both sides, serialize, compare name-aligned.""" + n_pt = _swap_data_with_grad(pt_mod) + n_expt = _swap_data_with_grad(expt_mod) + # exact trainable-parameter count parity: a buffer wrongly promoted (or a + # parameter left as buffer) changes the count + assert n_pt == n_expt, f"trainable parameter count: pt {n_pt} vs pt_expt {n_expt}" + ref = _flatten_arrays(pt_mod.serialize()) + res = _flatten_arrays(expt_mod.serialize()) + assert sorted(ref) == sorted(res) + for key in sorted(ref): + np.testing.assert_allclose( + res[key], + ref[key], + rtol=rtol, + atol=atol, + err_msg=f"gradient mismatch for {key}", + ) + + +class TestDescriptorGradParity: + nf = 2 + nloc = 6 + nall = 10 + nnei = 20 + ntypes = 2 + + def _build_pair(self, **overrides): + from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, + ) + from deepmd.pt_expt.descriptor.dpa4 import ( + DescrptDPA4, + ) + + kwargs = { + "ntypes": self.ntypes, + "sel": self.nnei, + "rcut": 4.0, + "channels": 16, + "n_radial": 8, + "lmax": 2, + "mmax": 1, + "n_blocks": 2, + "grid_branch": [1, 1, 1], + "s2_activation": [False, True], + "random_gamma": False, + "precision": "float64", + "seed": 7, + } + kwargs.update(overrides) + pt_mod = DescrptSeZM(**kwargs).double() + # several projections are zero-initialized; perturb for nonzero + # output and weight-dependent gradients everywhere + rng = np.random.default_rng(2150) + with torch.no_grad(): + for p in pt_mod.parameters(): + p += to_pt(0.05 * rng.normal(size=tuple(p.shape))) + expt_mod = DescrptDPA4.deserialize(pt_mod.serialize()) + return pt_mod, expt_mod + + def _inputs(self, seed=2151): + rng = np.random.default_rng(seed) + return _build_descriptor_inputs( + rng, + nf=self.nf, + nloc=self.nloc, + nall=self.nall, + nnei=self.nnei, + ntypes=self.ntypes, + ) + + @pytest.mark.parametrize("use_env_seed", [False, True]) # env FiLM (film_* params) + def test_descriptor_grad_parity(self, use_env_seed) -> None: + pt_mod, expt_mod = self._build_pair(use_env_seed=use_env_seed) + inp = self._inputs() + coord = inp["coord"].reshape(self.nf, -1) + atype_ext, nlist, mapping = inp["atype_ext"], inp["nlist"], inp["mapping"] + + out_pt = pt_mod( + to_pt(inp["coord"]), + to_pt(atype_ext), + to_pt(nlist), + mapping=to_pt(mapping), + )[0] + out_expt = expt_mod( + to_pt(coord), + to_pt(atype_ext.astype(np.int64)), + to_pt(nlist.astype(np.int64)), + mapping=to_pt(mapping.astype(np.int64)), + )[0] + # guard: forward outputs must match before comparing gradients + np.testing.assert_allclose( + out_expt.detach().cpu().numpy(), + out_pt.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-12, + ) + # quadratic loss -> dL/dw depends on the weights, not just the inputs + (out_pt**2).sum().backward() + (out_expt**2).sum().backward() + # descriptor-level gate (same as the forward parity gate in + # test_dpa4_dpmodel_parity.py): grads chain the full descriptor + # math, where fp64 accumulation-order drift reaches ~3e-11 rel + _assert_grad_trees_match(pt_mod, expt_mod, rtol=1e-10, atol=1e-12) + + +class TestFittingGradParity: + nf = 2 + nloc = 6 + in_dim = 12 + ntypes = 2 + + def _build_pair(self, **overrides): + from deepmd.pt.model.task.sezm_ener import ( + SeZMEnergyFittingNet as SeZMEnergyFittingNetPT, + ) + from deepmd.pt_expt.fitting.dpa4_ener import ( + SeZMEnergyFittingNet as SeZMEnergyFittingNetExpt, + ) + + kwargs = { + "ntypes": self.ntypes, + "dim_descrpt": self.in_dim, + "neuron": [16, 16], + "precision": "float64", + "seed": 7, + } + kwargs.update(overrides) + pt_mod = SeZMEnergyFittingNetPT(**kwargs) + # bias_atom_e is zero-initialized; perturb for a nontrivial bias path + rng = np.random.default_rng(2160) + with torch.no_grad(): + pt_mod.bias_atom_e += to_pt( + rng.normal(size=tuple(pt_mod.bias_atom_e.shape)) + ) + expt_mod = SeZMEnergyFittingNetExpt.deserialize(pt_mod.serialize()) + return pt_mod, expt_mod + + def _inputs(self, seed=2161): + rng = np.random.default_rng(seed) + descriptor = rng.normal(size=(self.nf, self.nloc, self.in_dim)) + atype = rng.integers(0, self.ntypes, size=(self.nf, self.nloc)) + atype[0, 0], atype[0, 1] = 0, 1 + return descriptor, atype + + @pytest.mark.parametrize("bias_out", [False, True]) # output-layer bias + def test_fitting_grad_parity(self, bias_out) -> None: + pt_mod, expt_mod = self._build_pair(bias_out=bias_out) + descriptor, atype = self._inputs() + + out_pt = pt_mod(to_pt(descriptor), to_pt(atype))["energy"] + out_expt = expt_mod(to_pt(descriptor), to_pt(atype.astype(np.int64)))["energy"] + # guard: forward outputs must match before comparing gradients + np.testing.assert_allclose( + out_expt.detach().cpu().numpy(), + out_pt.detach().cpu().numpy(), + rtol=PT_RTOL, + atol=PT_ATOL, + ) + # quadratic loss: energy.sum() would make the grad of the last + # linear layer independent of upstream weights + (out_pt**2).sum().backward() + (out_expt**2).sum().backward() + _assert_grad_trees_match(pt_mod, expt_mod) diff --git a/source/tests/pt_expt/descriptor/test_dpa4.py b/source/tests/pt_expt/descriptor/test_dpa4.py index 20efe20b28..3955491810 100644 --- a/source/tests/pt_expt/descriptor/test_dpa4.py +++ b/source/tests/pt_expt/descriptor/test_dpa4.py @@ -145,3 +145,43 @@ def fn(coord_ext, atype_ext, nlist): rtol=1e-12, atol=1e-12, ) + + def test_trainable_parameters(self) -> None: + # `_promote_trainable_tree` must promote every weight that is a + # trainable nn.Parameter in the reference pt SeZM implementation + # (full 1:1 gradient parity is proven in + # source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py) + dd0 = make_descriptor(self.nt, self.sel_mix, self.rcut, use_env_seed=True).to( + self.device + ) + param_names = dict(dd0.named_parameters()) + buffer_names = dict(dd0.named_buffers()) + # spot-check known trainable weights are Parameters + for name in ( + "type_embedding.adam_type_embedding", + "radial_basis.adam_freqs", + "film_scale_strength_log", + "blocks.0.so2_conv.so2_linears.0.weight_m0", + "blocks.0.so2_conv.so2_linears.0.weight_m.0", + "blocks.0.so2_conv.non_linearities.0.gate_linear.weight", + "blocks.0.post_so2_norm.adam_scale", + "blocks.0.ffns.0.act.grid_op.left_proj.weight", + "output_ffn.so3_linear_1.weight", + ): + assert name in param_names, f"{name} not promoted to Parameter" + assert param_names[name].requires_grad + # spot-check constants stay buffers (pt registers them as buffers) + for name in ( + "mean", + "stddev", + "blocks.0.post_so2_norm.balance_weight", + "blocks.0.so2_conv.rotate_inv_rescale_full", + ): + assert name in buffer_names, f"{name} should stay a buffer" + assert name not in param_names + # wigner tables must never be trainable + assert not any("wigner" in n.lower() for n in param_names) + # all promoted parameters are float and trainable + for name, p in param_names.items(): + assert p.is_floating_point(), name + assert p.requires_grad, name From 1f26ba13597606e9e9b3510087195e2519b5b18b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 13 Jun 2026 00:10:38 +0800 Subject: [PATCH 08/12] test(consistent): enable pt_expt rows for DPA4 --- .../tests/consistent/descriptor/test_dpa4.py | 19 ++++++++++++-- .../consistent/fitting/test_dpa4_ener.py | 25 +++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa4.py b/source/tests/consistent/descriptor/test_dpa4.py index ced71078f3..7cfb642ac2 100644 --- a/source/tests/consistent/descriptor/test_dpa4.py +++ b/source/tests/consistent/descriptor/test_dpa4.py @@ -20,6 +20,7 @@ from ..common import ( INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized_cases, ) @@ -31,6 +32,10 @@ from deepmd.pt.model.descriptor.sezm import DescrptSeZM as DescrptDPA4PT else: DescrptDPA4PT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.dpa4 import DescrptDPA4 as DescrptDPA4PTExpt +else: + DescrptDPA4PTExpt = None # not implemented DescrptDPA4TF = None @@ -119,13 +124,13 @@ def skip_pt(self) -> bool: skip_tf = True skip_jax = True skip_pd = True - skip_pt_expt = True + skip_pt_expt = not INSTALLED_PT_EXPT skip_array_api_strict = True tf_class = DescrptDPA4TF dp_class = DescrptDPA4DP pt_class = DescrptDPA4PT - pt_expt_class = None + pt_expt_class = DescrptDPA4PTExpt jax_class = None pd_class = None array_api_strict_class = None @@ -191,6 +196,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],) diff --git a/source/tests/consistent/fitting/test_dpa4_ener.py b/source/tests/consistent/fitting/test_dpa4_ener.py index 666594bf64..64b7918919 100644 --- a/source/tests/consistent/fitting/test_dpa4_ener.py +++ b/source/tests/consistent/fitting/test_dpa4_ener.py @@ -16,6 +16,7 @@ from ..common import ( INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) @@ -30,6 +31,15 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: SeZMEnerFittingPT = None +if INSTALLED_PT_EXPT: + import torch + + from deepmd.pt_expt.fitting.dpa4_ener import ( + SeZMEnergyFittingNet as SeZMEnerFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + SeZMEnerFittingPTExpt = None # not implemented SeZMEnerFittingTF = None @@ -61,13 +71,13 @@ def skip_pt(self) -> bool: skip_tf = True skip_jax = True skip_pd = True - skip_pt_expt = True + skip_pt_expt = not INSTALLED_PT_EXPT skip_array_api_strict = True tf_class = SeZMEnerFittingTF dp_class = SeZMEnerFittingDP pt_class = SeZMEnerFittingPT - pt_expt_class = None + pt_expt_class = SeZMEnerFittingPTExpt jax_class = None pd_class = None array_api_strict_class = None @@ -112,6 +122,17 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + )["energy"] + .detach() + .cpu() + .numpy() + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret,) From b1920eecb3a90fa845565d781192c0265b4b4ea1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:19:14 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/descriptor/dpa4.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py | 2 +- deepmd/dpmodel/descriptor/dpa4_nn/embedding.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/indexing.py | 2 +- deepmd/dpmodel/descriptor/dpa4_nn/norm.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/projection.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/radial.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/so3.py | 7 +++---- deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py | 7 +++---- deepmd/dpmodel/fitting/dpa4_ener.py | 6 +++--- 10 files changed, 26 insertions(+), 33 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4.py b/deepmd/dpmodel/descriptor/dpa4.py index effe51afdb..8310c4b5ce 100644 --- a/deepmd/dpmodel/descriptor/dpa4.py +++ b/deepmd/dpmodel/descriptor/dpa4.py @@ -39,6 +39,9 @@ from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( PRECISION_DICT, get_xp_precision, @@ -96,10 +99,6 @@ WignerDCalculator, ) -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) - if TYPE_CHECKING: from deepmd.dpmodel.array_api import ( Array, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py b/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py index 03c3a57031..8d2b5cabf1 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py @@ -39,11 +39,11 @@ ) import array_api_compat +import numpy as np from deepmd.dpmodel.array_api import ( xp_asarray_nodetach, ) -import numpy as np from .utils import ( safe_norm, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py index 40f59b3b00..73ac77456d 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/embedding.py @@ -39,10 +39,6 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( @@ -50,6 +46,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( to_numpy_array, ) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py b/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py index 080bcb58d9..70ee724214 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/indexing.py @@ -22,11 +22,11 @@ ) import array_api_compat +import numpy as np from deepmd.dpmodel.array_api import ( xp_asarray_nodetach, ) -import numpy as np def get_so3_dim_of_lmax(lmax: int) -> int: diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py index 0e5cafde64..1f08b4ce06 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py @@ -21,10 +21,6 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( @@ -32,6 +28,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( to_numpy_array, ) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index b21ccfaccb..e5131b5d12 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -42,10 +42,6 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( @@ -53,6 +49,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.utils.lebedev import ( LEBEDEV_PRECISION_TO_NPOINTS, load_lebedev_rule, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/radial.py b/deepmd/dpmodel/descriptor/dpa4_nn/radial.py index 4428aab5fb..94e506fd0b 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/radial.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/radial.py @@ -22,10 +22,6 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( @@ -33,6 +29,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( to_numpy_array, ) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py index 699d8cbc93..d16d75002a 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py @@ -32,10 +32,6 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( @@ -43,6 +39,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( to_numpy_array, ) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py index c630a952f5..c63f4b3eb6 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py @@ -40,16 +40,15 @@ ) import array_api_compat - -from deepmd.dpmodel.array_api import ( - xp_asarray_nodetach, -) import numpy as np from deepmd.dpmodel import ( DEFAULT_PRECISION, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_asarray_nodetach, +) from deepmd.dpmodel.common import ( get_xp_precision, ) diff --git a/deepmd/dpmodel/fitting/dpa4_ener.py b/deepmd/dpmodel/fitting/dpa4_ener.py index f8f1d8c8da..f08097f032 100644 --- a/deepmd/dpmodel/fitting/dpa4_ener.py +++ b/deepmd/dpmodel/fitting/dpa4_ener.py @@ -15,12 +15,12 @@ DEFAULT_PRECISION, NativeOP, ) -from deepmd.dpmodel.common import ( - to_numpy_array, -) from deepmd.dpmodel.array_api import ( Array, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.network import ( NativeLayer, get_activation_fn, From 374930088120e1c39f81be2e2f6f555e1f916774 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 13 Jun 2026 09:30:07 +0800 Subject: [PATCH 10/12] fix(pt_expt): address review findings on DPA4 PR (trainable=False freeze, component type validation) --- deepmd/pt_expt/descriptor/dpa4.py | 8 +++++++ deepmd/pt_expt/model/get_model.py | 12 +++++++++++ source/tests/pt_expt/descriptor/test_dpa4.py | 20 ++++++++++++++++++ .../pt_expt/model/test_get_model_dpa4.py | 21 +++++++++++++++++++ 4 files changed, 61 insertions(+) diff --git a/deepmd/pt_expt/descriptor/dpa4.py b/deepmd/pt_expt/descriptor/dpa4.py index dd93ac5e3c..cd41b1aaf8 100644 --- a/deepmd/pt_expt/descriptor/dpa4.py +++ b/deepmd/pt_expt/descriptor/dpa4.py @@ -114,6 +114,14 @@ def _promote_trainable_tree(module: torch.nn.Module) -> torch.nn.Module: names = _TRAINABLE_ATTRS.get(type(sub).__name__) if names is not None: _promote_trainable(sub, names) + # Freeze every Parameter under a ``trainable=False`` module. This covers + # parameters that exist regardless of the promotion table above, e.g. the + # ``SO2Linear.weight_m`` list, which ``_try_convert_list`` converts to a + # ParameterList with ``requires_grad=True`` unconditionally. + for sub in module.modules(): + if getattr(sub, "trainable", True) is False: + for p in sub.parameters(recurse=True): + p.requires_grad_(False) return module diff --git a/deepmd/pt_expt/model/get_model.py b/deepmd/pt_expt/model/get_model.py index 94c3a3ae41..c0ac078024 100644 --- a/deepmd/pt_expt/model/get_model.py +++ b/deepmd/pt_expt/model/get_model.py @@ -153,6 +153,18 @@ def get_sezm_model(data: dict) -> EnergyModel: data.setdefault("fitting_net", {}) data["descriptor"].setdefault("type", "dpa4") data["fitting_net"].setdefault("type", "dpa4_ener") + # the DPA4/SeZM model type is a fixed descriptor/fitting contract; reject + # explicit mismatching component types instead of silently building them + if data["descriptor"]["type"] not in ("dpa4", "DPA4", "sezm", "SeZM"): + raise ValueError( + "Model type 'dpa4' requires a DPA4/SeZM descriptor, but got " + f"descriptor type '{data['descriptor']['type']}'." + ) + if data["fitting_net"]["type"] not in ("dpa4_ener", "sezm_ener"): + raise ValueError( + "Model type 'dpa4' requires the DPA4/SeZM energy fitting net, but got " + f"fitting_net type '{data['fitting_net']['type']}'." + ) # keep descriptor.exclude_types and model pair_exclude_types consistent descriptor_exclude_types = [ diff --git a/source/tests/pt_expt/descriptor/test_dpa4.py b/source/tests/pt_expt/descriptor/test_dpa4.py index 3955491810..ef4a10227b 100644 --- a/source/tests/pt_expt/descriptor/test_dpa4.py +++ b/source/tests/pt_expt/descriptor/test_dpa4.py @@ -185,3 +185,23 @@ def test_trainable_parameters(self) -> None: for name, p in param_names.items(): assert p.is_floating_point(), name assert p.requires_grad, name + + @pytest.mark.parametrize( + "via_deserialize", [False, True] + ) # constructor vs round-trip + def test_trainable_false_freezes_all_parameters(self, via_deserialize) -> None: + # trainable=False must freeze every parameter, including ParameterList + # entries such as SO2Linear.weight_m (mmax>=1) that dpmodel_setattr + # converts with requires_grad=True + dd0 = make_descriptor( + self.nt, self.sel_mix, self.rcut, use_env_seed=True, trainable=False + ).to(self.device) + if via_deserialize: + dd0 = DescrptDPA4.deserialize(dd0.serialize()) + params = dict(dd0.named_parameters()) + assert any(".weight_m." in n for n in params) # mmax>=1 exercised + frozen = [n for n, p in params.items() if not p.requires_grad] + assert frozen == list(params), ( + f"trainable=False left parameters trainable: " + f"{sorted(set(params) - set(frozen))}" + ) diff --git a/source/tests/pt_expt/model/test_get_model_dpa4.py b/source/tests/pt_expt/model/test_get_model_dpa4.py index 8788952eee..f550d8fab2 100644 --- a/source/tests/pt_expt/model/test_get_model_dpa4.py +++ b/source/tests/pt_expt/model/test_get_model_dpa4.py @@ -107,6 +107,27 @@ def test_descriptor_fitting_type_defaults(self) -> None: model = get_model(raw) self.assertIsInstance(model, EnergyModel) + def test_explicit_matching_component_types_ok(self) -> None: + """Explicit dpa4/sezm descriptor and fitting types are accepted.""" + for desc_type, fit_type in (("dpa4", "dpa4_ener"), ("sezm", "sezm_ener")): + raw = _make_raw_model_config() + raw["descriptor"]["type"] = desc_type + raw["fitting_net"]["type"] = fit_type + model = get_model(raw) + self.assertIsInstance(model, EnergyModel, msg=f"{desc_type}/{fit_type}") + + def test_explicit_mismatching_descriptor_type_raises(self) -> None: + raw = _make_raw_model_config() + raw["descriptor"]["type"] = "se_e2_a" + with self.assertRaisesRegex(ValueError, "requires a DPA4/SeZM descriptor"): + get_model(raw) + + def test_explicit_mismatching_fitting_type_raises(self) -> None: + raw = _make_raw_model_config() + raw["fitting_net"]["type"] = "ener" + with self.assertRaisesRegex(ValueError, "energy fitting net"): + get_model(raw) + def test_pair_exclude_types_from_descriptor(self) -> None: """descriptor.exclude_types propagates when pair_exclude_types absent.""" raw = _make_raw_model_config() From 79bce06e5a9d228e139a7173f9eae3b25088dc94 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 14 Jun 2026 11:39:22 +0800 Subject: [PATCH 11/12] fix(dpa4): convert rebuilt index buffers via to_numpy_array in deserialize (CUDA) --- deepmd/dpmodel/descriptor/dpa4_nn/activation.py | 2 +- deepmd/dpmodel/descriptor/dpa4_nn/norm.py | 4 ++-- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 6 +++++- deepmd/dpmodel/descriptor/dpa4_nn/so3.py | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py index 6e5ba7dbc2..8b9897ff07 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/activation.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/activation.py @@ -287,7 +287,7 @@ def deserialize(cls, data: dict[str, Any]) -> GatedActivation: ) prec = PRECISION_DICT[obj.precision.lower()] expand_index = np.asarray(variables["expand_index"], dtype=np.int64) - if not np.array_equal(expand_index, obj.expand_index): + if not np.array_equal(expand_index, to_numpy_array(obj.expand_index)): raise ValueError("expand_index does not match the lmax/mmax tables") if obj.gate_linear is not None: weight = np.asarray(variables["gate_linear.weight"], dtype=prec) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py index 1f08b4ce06..02cc4755dc 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/norm.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/norm.py @@ -326,7 +326,7 @@ def deserialize(cls, data: dict[str, Any]) -> EquivariantRMSNorm: ) prec = PRECISION_DICT[obj.precision.lower()] expand_index = np.asarray(variables["expand_index"], dtype=np.int64) - if not np.array_equal(expand_index, obj.expand_index): + if not np.array_equal(expand_index, to_numpy_array(obj.expand_index)): raise ValueError("expand_index does not match the lmax-derived table") for name in ("adam_scale", "bias", "balance_weight"): value = np.asarray(variables[name], dtype=prec) @@ -535,7 +535,7 @@ def deserialize(cls, data: dict[str, Any]) -> ReducedEquivariantRMSNorm: ) prec = PRECISION_DICT[obj.precision.lower()] degree_index_m = np.asarray(variables["degree_index_m"], dtype=np.int64) - if not np.array_equal(degree_index_m, obj.degree_index_m): + if not np.array_equal(degree_index_m, to_numpy_array(obj.degree_index_m)): raise ValueError("degree_index_m variable does not match the config") for name in ("balance_weight", "adam_scale", "bias0"): value = np.asarray(variables[name], dtype=prec) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 1986b5e9c9..ac4f7d298b 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -119,7 +119,11 @@ def _check_shape_assign(obj: Any, attr: str, value: Any, dtype: Any, key: str) - def _check_index_table(expected: np.ndarray, value: Any, key: str) -> None: """Validate that a serialized integer index table matches the rebuilt one.""" arr = np.asarray(value, dtype=np.int64) - if not np.array_equal(arr.reshape(-1), np.asarray(expected).reshape(-1)): + # ``expected`` is a rebuilt buffer that may be a (possibly CUDA) torch + # tensor in the pt_expt backend; ``np.asarray`` raises on CUDA tensors and + # ``np.array_equal`` would silently swallow that into ``False``. Convert via + # ``to_numpy_array`` (dlpack-through-CPU fallback) before comparing. + if not np.array_equal(arr.reshape(-1), to_numpy_array(expected).reshape(-1)): raise ValueError(f"{key} does not match the table derived from the config") diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py index d16d75002a..6a4b36aa4d 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so3.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so3.py @@ -539,7 +539,7 @@ def deserialize(cls, data: dict[str, Any]) -> SO3Linear: ) prec = PRECISION_DICT[obj.precision.lower()] expand_index = np.asarray(variables["expand_index"], dtype=np.int64) - if not np.array_equal(expand_index, obj.expand_index): + if not np.array_equal(expand_index, to_numpy_array(obj.expand_index)): raise ValueError("expand_index does not match the lmax-derived table") weight = np.asarray(variables["weight"], dtype=prec) if weight.shape != obj.weight.shape: From 56793cfa3048da6adb40993f3510e11bbb2fd211 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 14 Jun 2026 11:49:05 +0800 Subject: [PATCH 12/12] fix(dpa4): convert mean/stddev/wigner buffers via to_numpy_array in serialize (CUDA) --- deepmd/dpmodel/descriptor/dpa4.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4.py b/deepmd/dpmodel/descriptor/dpa4.py index 8310c4b5ce..338ee3353e 100644 --- a/deepmd/dpmodel/descriptor/dpa4.py +++ b/deepmd/dpmodel/descriptor/dpa4.py @@ -1053,8 +1053,8 @@ def _variables(self) -> dict[str, np.ndarray]: # pt interface-compatibility buffers "version_tensor": np.asarray(self.version, dtype=np.float64), "_empty_tensor": np.zeros((0,), dtype=np.float64), - "mean": np.asarray(self.mean, dtype=model_np_prec), - "stddev": np.asarray(self.stddev, dtype=model_np_prec), + "mean": to_numpy_array(self.mean).astype(model_np_prec), + "stddev": to_numpy_array(self.stddev).astype(model_np_prec), } def add(prefix: str, sub_vars: dict[str, Any]) -> None: @@ -1083,7 +1083,7 @@ def add(prefix: str, sub_vars: dict[str, Any]) -> None: def wigner_buffers(calc: WignerDCalculator) -> dict[str, np.ndarray]: return { "l1_perm": np.asarray([1, 2, 0], dtype=np.int64), - "l1_sign_outer": np.asarray(calc.l1_sign_outer, dtype=np.float64), + "l1_sign_outer": to_numpy_array(calc.l1_sign_outer).astype(np.float64), } add("wigner_calc.", wigner_buffers(self.wigner_calc))