From ec2e031e4384a66e96c2d64d45f4dec897bb769d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:48:20 +0800 Subject: [PATCH 01/33] implement pytorch-exportable for se_e2_a descriptor --- deepmd/backend/pt_expt.py | 126 ++++++++++++++++ deepmd/dpmodel/descriptor/se_e2_a.py | 6 +- deepmd/pt_expt/__init__.py | 1 + deepmd/pt_expt/descriptor/__init__.py | 8 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 101 +++++++++++++ deepmd/pt_expt/utils/__init__.py | 1 + deepmd/pt_expt/utils/network.py | 130 +++++++++++++++++ source/tests/consistent/common.py | 80 ++++++++++- source/tests/consistent/descriptor/common.py | 31 +++- .../consistent/descriptor/test_se_e2_a.py | 60 ++++++++ source/tests/pt_expt/__init__.py | 1 + source/tests/pt_expt/model/__init__.py | 1 + source/tests/pt_expt/model/test_se_e2_a.py | 135 ++++++++++++++++++ 13 files changed, 676 insertions(+), 5 deletions(-) create mode 100644 deepmd/backend/pt_expt.py create mode 100644 deepmd/pt_expt/__init__.py create mode 100644 deepmd/pt_expt/descriptor/__init__.py create mode 100644 deepmd/pt_expt/descriptor/se_e2_a.py create mode 100644 deepmd/pt_expt/utils/__init__.py create mode 100644 deepmd/pt_expt/utils/network.py create mode 100644 source/tests/pt_expt/__init__.py create mode 100644 source/tests/pt_expt/model/__init__.py create mode 100644 source/tests/pt_expt/model/test_se_e2_a.py diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py new file mode 100644 index 0000000000..38745c690c --- /dev/null +++ b/deepmd/backend/pt_expt.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Callable, +) +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + ClassVar, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("pt-expt") +@Backend.register("pytorch-exportable") +class PyTorchExportableBackend(Backend): + """PyTorch exportable backend.""" + + name = "PyTorch Exportable" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO + ) + """The features of the backend.""" + suffixes: ClassVar[list[str]] = [".pth", ".pt"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return find_spec("torch") is not None + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + from deepmd.pt.entrypoints.main import main as deepmd_main + + return deepmd_main + + @property + def deep_eval(self) -> type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT + + return DeepEvalPT + + @property + def neighbor_stat(self) -> type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.pt.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat + + @property + def serialize_hook(self) -> Callable[[str], dict]: + """The serialize hook to convert the model file to a dictionary. + + Returns + ------- + Callable[[str], dict] + The serialize hook of the backend. + """ + from deepmd.pt.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file + + @property + def deserialize_hook(self) -> Callable[[str, dict], None]: + """The deserialize hook to convert the dictionary to a model file. + + Returns + ------- + Callable[[str, dict], None] + The deserialize hook of the backend. + """ + from deepmd.pt.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index c09a6cbdc3..a6b17bf69a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -607,7 +607,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) + gr = xp.zeros( + [nf * nloc, ng, 4], + dtype=self.dstd.dtype, + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/pt_expt/__init__.py b/deepmd/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..fdac48ed41 --- /dev/null +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", +] diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py new file mode 100644 index 0000000000..4334011ec3 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch # noqa: TID253 + +from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 + BaseDescriptor, +) +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseDescriptor.register("se_e2_a_expt") +@BaseDescriptor.register("se_a_expt") +class DescrptSeA(DescrptSeADP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeADP.__init__(self, *args, **kwargs) + self._convert_state() + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + return super().__setattr__(name, tensor) + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def _convert_state(self) -> None: + if self.davg is not None: + davg = torch.as_tensor(self.davg, device=env.DEVICE) + if "davg" in self._buffers: + self._buffers["davg"] = davg + else: + if hasattr(self, "davg"): + delattr(self, "davg") + self.register_buffer("davg", davg) + if self.dstd is not None: + dstd = torch.as_tensor(self.dstd, device=env.DEVICE) + if "dstd" in self._buffers: + self._buffers["dstd"] = dstd + else: + if hasattr(self, "dstd"): + delattr(self, "dstd") + self.register_buffer("dstd", dstd) + if self.embeddings is not None: + self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) + if self.emask is not None: + self.emask = PairExcludeMask( + self.ntypes, exclude_types=list(self.emask.get_exclude_types()) + ) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py new file mode 100644 index 0000000000..f29d8970b3 --- /dev/null +++ b/deepmd/pt_expt/utils/network.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, + Self, +) + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) +from deepmd.pt.utils import ( # noqa: TID253 + env, +) + + +def _to_torch_array(value: Any) -> torch.Tensor | None: + if value is None: + return None + if torch.is_tensor(value): + return value + return torch.as_tensor(value, device=env.DEVICE) + + +class TorchArrayParam(torch.nn.Parameter): + def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: + return torch.nn.Parameter.__new__(cls, data, requires_grad) + + def __array__(self, dtype: Any | None = None) -> np.ndarray: + arr = self.detach().cpu().numpy() + if dtype is None: + return arr + return arr.astype(dtype) + + +class NativeLayer(NativeLayerDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + NativeLayerDP.__init__(self, *args, **kwargs) + for name in ("w", "b", "idt"): + if name in self._parameters or name in self._buffers: + continue + val = _to_torch_array(getattr(self, name)) + if val is None: + continue + if self.trainable: + if hasattr(self, name) and name not in self._parameters: + delattr(self, name) + self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) + else: + if hasattr(self, name) and name not in self._buffers: + delattr(self, name) + self.register_buffer(name, val) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: + val = _to_torch_array(value) + if val is None: + return super().__setattr__(name, None) + if getattr(self, "trainable", False): + param = ( + value + if isinstance(value, TorchArrayParam) + else TorchArrayParam(val, requires_grad=True) + ) + if name in self._parameters: + self._parameters[name] = param + return + return super().__setattr__(name, param) + if name in self._buffers: + self._buffers[name] = val + return + return super().__setattr__(name, val) + return super().__setattr__(name, value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module): + def __init__(self, layers: list[dict] | None = None) -> None: + torch.nn.Module.__init__(self) + super().__init__(layers) + self.layers = torch.nn.ModuleList(self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass + + +class NetworkCollection(NetworkCollectionDP, torch.nn.Module): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + super().__init__(*args, **kwargs) + self._module_networks = torch.nn.ModuleDict() + for idx, net in enumerate(self._networks): + if isinstance(net, torch.nn.Module): + self._module_networks[str(idx)] = net + + def __setitem__(self, key: int | tuple, value: Any) -> None: + super().__setitem__(key, value) + if isinstance(value, torch.nn.Module): + self._module_networks[str(self._convert_key(key))] = value + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 88fad4e10b..3d60f6def0 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -41,6 +41,11 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() +try: + _PT_EXPT_BACKEND = Backend.get_backend("pytorch-exportable") +except (KeyError, RuntimeError): + _PT_EXPT_BACKEND = None +INSTALLED_PT_EXPT = _PT_EXPT_BACKEND is not None and _PT_EXPT_BACKEND().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() INSTALLED_PD = Backend.get_backend("paddle")().is_available() INSTALLED_ARRAY_API_STRICT = find_spec("array_api_strict") is not None @@ -67,6 +72,7 @@ "INSTALLED_JAX", "INSTALLED_PD", "INSTALLED_PT", + "INSTALLED_PT_EXPT", "INSTALLED_TF", "CommonTest", "CommonTest", @@ -86,6 +92,8 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] """PyTorch model class.""" + pt_expt_class: ClassVar[type | None] + """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" pd_class: ClassVar[type | None] @@ -99,6 +107,8 @@ class CommonTest(ABC): """Whether to skip the TensorFlow model.""" skip_pt: ClassVar[bool] = not INSTALLED_PT """Whether to skip the PyTorch model.""" + skip_pt_expt: ClassVar[bool] = not INSTALLED_PT_EXPT + """Whether to skip the PyTorch exportable model.""" # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" @@ -176,6 +186,16 @@ def eval_pt(self, pt_obj: Any) -> Any: The object of PT """ + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + """Evaluate the return value of PT exportable. + + Parameters + ---------- + pt_expt_obj : Any + The object of PT exportable + """ + raise NotImplementedError("Not implemented") + def eval_jax(self, jax_obj: Any) -> Any: """Evaluate the return value of JAX. @@ -212,9 +232,10 @@ class RefBackend(Enum): TF = 1 DP = 2 PT = 3 - PD = 4 - JAX = 5 - ARRAY_API_STRICT = 6 + PT_EXPT = 4 + PD = 5 + JAX = 6 + ARRAY_API_STRICT = 7 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: @@ -275,6 +296,11 @@ def get_dp_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_pt_expt_ret_serialization_from_cls(self, obj): + ret = self.eval_pt_expt(obj) + data = obj.serialize() + return ret, data + def get_jax_ret_serialization_from_cls(self, obj): ret = self.eval_jax(obj) data = obj.serialize() @@ -301,6 +327,8 @@ def get_reference_backend(self): return self.RefBackend.TF if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_pd: @@ -320,6 +348,11 @@ def get_reference_ret_serialization(self, ref: RefBackend): if ref == self.RefBackend.PT: obj = self.init_backend_cls(self.pt_class) return self.get_pt_ret_serialization_from_cls(obj) + if ref == self.RefBackend.PT_EXPT: + if self.pt_expt_class is None: + raise ValueError("PT exportable class is not set") + obj = self.init_backend_cls(self.pt_expt_class) + return self.get_pt_expt_ret_serialization_from_cls(obj) if ref == self.RefBackend.JAX: obj = self.init_backend_cls(self.jax_class) return self.get_jax_ret_serialization_from_cls(obj) @@ -456,6 +489,47 @@ def test_pt_self_consistent(self) -> None: else: self.assertEqual(rr1, rr2) + def test_pt_expt_consistent_with_ref(self) -> None: + """Test whether PT exportable and reference are consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT_EXPT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self.pt_expt_class.deserialize(data1) + ret2 = self.eval_pt_expt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT_EXPT) + data2 = obj.serialize() + if obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): + common_keys = set(data1.keys()) & set(data2.keys()) + data1 = {k: data1[k] for k in common_keys} + data2 = {k: data2[k] for k in common_keys} + # drop @variables since they are not equal + data1.pop("@variables", None) + data2.pop("@variables", None) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_pt_expt_self_consistent(self) -> None: + """Test whether PT exportable is self consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.pt_expt_class) + ret1, data1 = self.get_pt_expt_ret_serialization_from_cls(obj1) + obj2 = self.pt_expt_class.deserialize(data1) + ret2, data2 = self.get_pt_expt_ret_serialization_from_cls(obj2) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def test_jax_consistent_with_ref(self) -> None: """Test whether JAX and reference are consistent.""" if self.skip_jax: diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 8af1c7ea64..7c8cbce744 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -21,10 +21,11 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, ) -if INSTALLED_PT: +if INSTALLED_PT or INSTALLED_PT_EXPT: import torch from deepmd.pt.utils.env import DEVICE as PT_DEVICE @@ -143,6 +144,34 @@ def eval_pt_descriptor( for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] + def eval_pt_expt_descriptor( + self, + pt_expt_obj: Any, + natoms: np.ndarray, + coords: np.ndarray, + atype: np.ndarray, + box: np.ndarray, + mixed_types: bool = False, + ) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_expt_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + pt_expt_obj.get_rcut(), + pt_expt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + x.detach().cpu().numpy() if torch.is_tensor(x) else x + for x in pt_expt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + ] + def eval_jax_descriptor( self, jax_obj: Any, diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index b345a61ed3..68a0068965 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -16,6 +16,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,10 @@ ) else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_e2_a import DescrptSeA as DescrptSeAPTExpt +else: + DescrptSeAPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeATF else: @@ -107,6 +112,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -165,6 +181,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -244,6 +261,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, @@ -351,6 +377,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -402,6 +439,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -505,6 +543,28 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + pt_expt_obj.compute_input_stats( + [ + { + "r0": None, + "coord": torch.from_numpy(self.coords) + .reshape(-1, self.natoms[0], 3) + .to(env.DEVICE), + "atype": torch.from_numpy(self.atype.reshape(1, -1)).to(env.DEVICE), + "box": torch.from_numpy(self.box.reshape(1, 3, 3)).to(env.DEVICE), + "natoms": self.natoms[0], + } + ] + ) + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: jax_obj.compute_input_stats( [ diff --git a/source/tests/pt_expt/__init__.py b/source/tests/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/__init__.py b/source/tests/pt_expt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py new file mode 100644 index 0000000000..b9b834849f --- /dev/null +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.env import ( # noqa: TID253 + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeA.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeA.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if em: + dd1.reinit_exclude([tuple(x) for x in em]) + self.assertIsInstance(dd1.emask, PairExcludeMask) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From b8a48ffe6bdd97ff3dbee7ef13182aa4bcf03a87 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:53:13 +0800 Subject: [PATCH 02/33] better type for xp.zeros --- deepmd/dpmodel/descriptor/se_e2_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a6b17bf69a..3ca28ba556 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -609,7 +609,7 @@ def call( ng = self.neuron[-1] gr = xp.zeros( [nf * nloc, ng, 4], - dtype=self.dstd.dtype, + dtype=input_dtype, device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) From 1cc001f7f262e6d76f40a8c7d36d9e80a9f1dd19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 09:51:22 +0800 Subject: [PATCH 03/33] implement env, base_descriptor and exclude_mask, remove the dependency on pt backend. --- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/base_descriptor.py | 10 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 11 +- deepmd/pt_expt/utils/__init__.py | 10 ++ deepmd/pt_expt/utils/env.py | 117 +++++++++++++++++++ deepmd/pt_expt/utils/exclude_mask.py | 27 +++++ deepmd/pt_expt/utils/network.py | 6 +- source/tests/pt_expt/model/test_se_e2_a.py | 16 +-- 8 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/base_descriptor.py create mode 100644 deepmd/pt_expt/utils/env.py create mode 100644 deepmd/pt_expt/utils/exclude_mask.py diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index fdac48ed41..089e5619e0 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .base_descriptor import ( + BaseDescriptor, +) from .se_e2_a import ( DescrptSeA, ) __all__ = [ + "BaseDescriptor", "DescrptSeA", ] diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py new file mode 100644 index 0000000000..51e9325bba --- /dev/null +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +from deepmd.dpmodel.descriptor import ( + make_base_descriptor, +) + +torch = importlib.import_module("torch") + +BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 4334011ec3..bb0c0cb2bd 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,24 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ) -import torch # noqa: TID253 - from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP -from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 +from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) from deepmd.pt_expt.utils.network import ( NetworkCollection, ) +torch = importlib.import_module("torch") + @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 6ceb116d85..f90cf82249 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from .exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +__all__ = [ + "AtomExcludeMask", + "PairExcludeMask", +] diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py new file mode 100644 index 0000000000..bd644e7206 --- /dev/null +++ b/deepmd/pt_expt/utils/env.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging +import multiprocessing +import os +import sys + +import numpy as np + +from deepmd.common import ( + VALID_PRECISION, +) +from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, + get_default_nthreads, + set_default_nthreads, +) + +log = logging.getLogger(__name__) +torch = importlib.import_module("torch") + +if sys.platform != "win32": + try: + multiprocessing.set_start_method("fork", force=True) + log.debug("Successfully set multiprocessing start method to 'fork'.") + except (RuntimeError, ValueError) as err: + log.warning(f"Could not set multiprocessing start method: {err}") +else: + log.debug("Skipping fork start method on Windows (not supported).") + +SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" +try: + # only linux + ncpus = len(os.sched_getaffinity(0)) +except AttributeError: + ncpus = os.cpu_count() +NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus))) +if multiprocessing.get_start_method() != "fork": + # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader + log.warning( + "NUM_WORKERS > 0 is not supported with spawn or forkserver start method. " + "Setting NUM_WORKERS to 0." + ) + NUM_WORKERS = 0 + +# Make sure DDP uses correct device if applicable +LOCAL_RANK = os.environ.get("LOCAL_RANK") +LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK) + +if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False: + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device(f"cuda:{LOCAL_RANK}") + +JIT = False +CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory +ENERGY_BIAS_TRAINABLE = True +CUSTOM_OP_USE_JIT = False + +PRECISION_DICT = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "half": torch.float16, + "single": torch.float32, + "double": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, + "bool": torch.bool, +} +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] +GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[ + np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name +] +PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) +# cannot automatically generated +RESERVED_PRECISION_DICT = { + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.int32: "int32", + torch.int64: "int64", + torch.bfloat16: "bfloat16", + torch.bool: "bool", +} +assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys()) +DEFAULT_PRECISION = "float64" + +# throw warnings if threads not set +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +if inter_nthreads > 0: # the behavior of 0 is not documented + torch.set_num_interop_threads(inter_nthreads) +if intra_nthreads > 0: + torch.set_num_threads(intra_nthreads) + +__all__ = [ + "CACHE_PER_SYS", + "CUSTOM_OP_USE_JIT", + "DEFAULT_PRECISION", + "DEVICE", + "ENERGY_BIAS_TRAINABLE", + "GLOBAL_ENER_FLOAT_PRECISION", + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_PT_ENER_FLOAT_PRECISION", + "GLOBAL_PT_FLOAT_PRECISION", + "JIT", + "LOCAL_RANK", + "NUM_WORKERS", + "PRECISION_DICT", + "RESERVED_PRECISION_DICT", + "SAMPLER_RECORD", +] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py new file mode 100644 index 0000000000..ed296c9f98 --- /dev/null +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +class AtomExcludeMask(AtomExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f29d8970b3..91a6999766 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ClassVar, @@ -6,7 +7,6 @@ ) import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.common import ( NativeOP, @@ -19,10 +19,12 @@ make_fitting_network, make_multilayer_network, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) +torch = importlib.import_module("torch") + def _to_torch_array(value: Any) -> torch.Tensor | None: if value is None: diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py index b9b834849f..57923b97a3 100644 --- a/source/tests/pt_expt/model/test_se_e2_a.py +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -1,23 +1,23 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib import itertools import unittest import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.env import ( # noqa: TID253 +from deepmd.pt_expt.utils.env import ( PRECISION_DICT, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) -from deepmd.pt_expt.descriptor.se_e2_a import ( - DescrptSeA, -) from ...pt.model.test_env_mat import ( TestCaseSingleFrameWithNlist, @@ -29,6 +29,8 @@ GLOBAL_SEED, ) +torch = importlib.import_module("torch") + class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: From f2fbe8884fdacf4369dbec847c06ebc54b453d02 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:08:29 +0800 Subject: [PATCH 04/33] mv to_torch_tensor to common --- deepmd/pt_expt/common.py | 35 +++++++++++++++++++++++++++++++++ deepmd/pt_expt/utils/network.py | 16 ++++----------- 2 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 deepmd/pt_expt/common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py new file mode 100644 index 0000000000..f065eeb76d --- /dev/null +++ b/deepmd/pt_expt/common.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, + overload, +) + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +@overload +def to_torch_array(array: np.ndarray) -> torch.Tensor: ... + + +@overload +def to_torch_array(array: None) -> None: ... + + +@overload +def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... + + +def to_torch_array(array: Any) -> torch.Tensor | None: + """Convert input to a torch tensor on the pt-expt device.""" + if array is None: + return None + if torch.is_tensor(array): + return array + return torch.as_tensor(array, device=env.DEVICE) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 91a6999766..18840200be 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,21 +19,13 @@ make_fitting_network, make_multilayer_network, ) -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + to_torch_array, ) torch = importlib.import_module("torch") -def _to_torch_array(value: Any) -> torch.Tensor | None: - if value is None: - return None - if torch.is_tensor(value): - return value - return torch.as_tensor(value, device=env.DEVICE) - - class TorchArrayParam(torch.nn.Parameter): def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: return torch.nn.Parameter.__new__(cls, data, requires_grad) @@ -52,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for name in ("w", "b", "idt"): if name in self._parameters or name in self._buffers: continue - val = _to_torch_array(getattr(self, name)) + val = to_torch_array(getattr(self, name)) if val is None: continue if self.trainable: @@ -66,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: - val = _to_torch_array(value) + val = to_torch_array(value) if val is None: return super().__setattr__(name, None) if getattr(self, "trainable", False): From e2afbe9c190ffef45315cac5089e067c7da800c5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:26:15 +0800 Subject: [PATCH 05/33] simplify __init__ of the NaiveLayer --- deepmd/pt_expt/utils/network.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 18840200be..5708197c66 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -41,20 +41,6 @@ class NativeLayer(NativeLayerDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) - for name in ("w", "b", "idt"): - if name in self._parameters or name in self._buffers: - continue - val = to_torch_array(getattr(self, name)) - if val is None: - continue - if self.trainable: - if hasattr(self, name) and name not in self._parameters: - delattr(self, name) - self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) - else: - if hasattr(self, name) and name not in self._buffers: - delattr(self, name) - self.register_buffer(name, val) def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: From 4ba511ac49d1093b3cabced160a953c90e8e0f81 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:32:09 +0800 Subject: [PATCH 06/33] fix bug --- deepmd/pt_expt/utils/network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5708197c66..f2230383de 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -3,7 +3,6 @@ from typing import ( Any, ClassVar, - Self, ) import numpy as np @@ -27,7 +26,9 @@ class TorchArrayParam(torch.nn.Parameter): - def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: + def __new__( # noqa: PYI034 + cls, data: Any = None, requires_grad: bool = True + ) -> "TorchArrayParam": return torch.nn.Parameter.__new__(cls, data, requires_grad) def __array__(self, dtype: Any | None = None) -> np.ndarray: From fb9598a68d13b1712b11b9ea481fd3e2ca85b502 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:43:19 +0800 Subject: [PATCH 07/33] fix bug --- deepmd/pt_expt/utils/network.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f2230383de..7a85634dca 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -95,16 +95,18 @@ class NetworkCollection(NetworkCollectionDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) - super().__init__(*args, **kwargs) self._module_networks = torch.nn.ModuleDict() - for idx, net in enumerate(self._networks): - if isinstance(net, torch.nn.Module): - self._module_networks[str(idx)] = net + super().__init__(*args, **kwargs) def __setitem__(self, key: int | tuple, value: Any) -> None: + idx = self._convert_key(key) super().__setitem__(key, value) - if isinstance(value, torch.nn.Module): - self._module_networks[str(self._convert_key(key))] = value + net = self._networks[idx] + key_str = str(idx) + if isinstance(net, torch.nn.Module): + self._module_networks[key_str] = net + elif key_str in self._module_networks: + del self._module_networks[key_str] class LayerNorm(LayerNormDP, NativeLayer): From fa03351be77fe9b1f1e5173d0855d6e6d912ab9e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:25:14 +0800 Subject: [PATCH 08/33] simplify init method of se_e2_a descriptor. fig bug in consistent UT --- deepmd/pt_expt/descriptor/se_e2_a.py | 25 ------------------------- source/tests/consistent/common.py | 2 +- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index bb0c0cb2bd..19a0d56734 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -27,7 +27,6 @@ class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) - self._convert_state() def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: @@ -53,30 +52,6 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return super().__setattr__(name, value) - def _convert_state(self) -> None: - if self.davg is not None: - davg = torch.as_tensor(self.davg, device=env.DEVICE) - if "davg" in self._buffers: - self._buffers["davg"] = davg - else: - if hasattr(self, "davg"): - delattr(self, "davg") - self.register_buffer("davg", davg) - if self.dstd is not None: - dstd = torch.as_tensor(self.dstd, device=env.DEVICE) - if "dstd" in self._buffers: - self._buffers["dstd"] = dstd - else: - if hasattr(self, "dstd"): - delattr(self, "dstd") - self.register_buffer("dstd", dstd) - if self.embeddings is not None: - self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) - if self.emask is not None: - self.emask = PairExcludeMask( - self.ntypes, exclude_types=list(self.emask.get_exclude_types()) - ) - def forward( self, nlist: torch.Tensor, diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 3d60f6def0..76b7e9cb53 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -92,7 +92,7 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] """PyTorch model class.""" - pt_expt_class: ClassVar[type | None] + pt_expt_class: ClassVar[type | None] = None """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" From 09b33f19daef30dad5dcd27be4811cde994b8bad Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:34:44 +0800 Subject: [PATCH 09/33] restructure the test folders. add test_common. --- deepmd/pt_expt/common.py | 2 +- source/tests/pt_expt/descriptor/__init__.py | 1 + .../{model => descriptor}/test_se_e2_a.py | 0 source/tests/pt_expt/utils/__init__.py | 1 + source/tests/pt_expt/utils/test_common.py | 25 +++++++++++++++++++ 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 source/tests/pt_expt/descriptor/__init__.py rename source/tests/pt_expt/{model => descriptor}/test_se_e2_a.py (100%) create mode 100644 source/tests/pt_expt/utils/__init__.py create mode 100644 source/tests/pt_expt/utils/test_common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index f065eeb76d..b66c0ff66d 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -31,5 +31,5 @@ def to_torch_array(array: Any) -> torch.Tensor | None: if array is None: return None if torch.is_tensor(array): - return array + return array.to(device=env.DEVICE) return torch.as_tensor(array, device=env.DEVICE) diff --git a/source/tests/pt_expt/descriptor/__init__.py b/source/tests/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py similarity index 100% rename from source/tests/pt_expt/model/test_se_e2_a.py rename to source/tests/pt_expt/descriptor/test_se_e2_a.py diff --git a/source/tests/pt_expt/utils/__init__.py b/source/tests/pt_expt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py new file mode 100644 index 0000000000..63c4983f23 --- /dev/null +++ b/source/tests/pt_expt/utils/test_common.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +import numpy as np + +from deepmd.pt_expt.common import ( + to_torch_array, +) +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +def test_to_torch_array_moves_device() -> None: + arr = np.arange(6, dtype=np.float32).reshape(2, 3) + tensor = to_torch_array(arr) + assert torch.is_tensor(tensor) + assert tensor.device == env.DEVICE + + input_tensor = torch.as_tensor(arr, device=torch.device("cpu")) + output_tensor = to_torch_array(input_tensor) + assert torch.is_tensor(output_tensor) + assert output_tensor.device == env.DEVICE From 67f2e544a6d228652ba9ea311a52a5c8defec210 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:54:41 +0800 Subject: [PATCH 10/33] add test_exclusion_mask.py --- .../pt_expt/utils/test_exclusion_mask.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_exclusion_mask.py diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py new file mode 100644 index 0000000000..7168579052 --- /dev/null +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import unittest + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +torch = importlib.import_module("torch") + + +class TestAtomExcludeMask(unittest.TestCase): + def test_build_type_exclude_mask(self) -> None: + nf = 2 + nt = 3 + exclude_types = [0, 2] + atype = np.array( + [ + [0, 2, 1, 2, 0, 1, 0], + [1, 2, 0, 0, 2, 2, 1], + ], + dtype=np.int32, + ).reshape([nf, -1]) + expected_mask = np.array( + [ + [0, 0, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 1], + ] + ).reshape([nf, -1]) + des = AtomExcludeMask(nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + +class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_build_type_exclude_mask(self) -> None: + exclude_types = [[0, 1]] + expected_mask = np.array( + [ + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + [0, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + ] + ).reshape(self.nf, self.nloc, sum(self.sel)) + des = PairExcludeMask(self.nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask( + torch.as_tensor(self.nlist, device=env.DEVICE), + torch.as_tensor(self.atype_ext, device=env.DEVICE), + ) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) From f7d83ddfae60920d05b76f8a13fd4a35bb350a79 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:58:38 +0800 Subject: [PATCH 11/33] fix poitential import issue in test. --- source/tests/pt_expt/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 source/tests/pt_expt/conftest.py diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py new file mode 100644 index 0000000000..ec025c2202 --- /dev/null +++ b/source/tests/pt_expt/conftest.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +pytest.importorskip("torch") From 0c96bb6fecf1564433d0f98d00d05866dcb9fbd5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:11:12 +0800 Subject: [PATCH 12/33] correct __call__(). fix bug --- deepmd/pt_expt/descriptor/se_e2_a.py | 6 +++++- deepmd/pt_expt/utils/network.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 19a0d56734..7a4d4a71d9 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -28,6 +28,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: tensor = ( @@ -54,9 +58,9 @@ def __setattr__(self, name: str, value: Any) -> None: def forward( self, - nlist: torch.Tensor, extended_coord: torch.Tensor, extended_atype: torch.Tensor, + nlist: torch.Tensor, extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, type_embedding: torch.Tensor | None = None, diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 7a85634dca..fffb98b1ef 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -43,6 +43,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) @@ -74,6 +77,9 @@ def __init__(self, layers: list[dict] | None = None) -> None: super().__init__(layers) self.layers = torch.nn.ModuleList(self.layers) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) From 9dca9128b50ecf7a92f902e6a6dca0b475ac1e9c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:54:28 +0800 Subject: [PATCH 13/33] fix registration issue --- deepmd/pt_expt/descriptor/se_e2_a.py | 4 +++- deepmd/pt_expt/utils/network.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7a4d4a71d9..7df1148e38 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -40,7 +40,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = tensor return - return super().__setattr__(name, tensor) + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return if name == "embeddings" and "_modules" in self.__dict__: if value is not None and not isinstance(value, torch.nn.Module): if hasattr(value, "serialize"): diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index fffb98b1ef..5f66959d16 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -64,7 +64,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = val return - return super().__setattr__(name, val) + # Register on first assignment so tensors are in state_dict and moved by .to(). + self.register_buffer(name, val) + return return super().__setattr__(name, value) def forward(self, x: torch.Tensor) -> torch.Tensor: From 17f0a5d1ae06eecfa25b4f2e3e2e09a4570fc0ef Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:00:54 +0800 Subject: [PATCH 14/33] fix pt-expt file extension --- deepmd/backend/pt_expt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index 38745c690c..e651332e2b 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -41,7 +41,7 @@ class PyTorchExportableBackend(Backend): | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".pth", ".pt"] + suffixes: ClassVar[list[str]] = [".pte"] """The suffixes of the backend.""" def is_available(self) -> bool: From 8ce93baafa42c0c51046152084765f44101809b9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:03:21 +0800 Subject: [PATCH 15/33] fix(pt): expansion of get_default_nthreads() --- deepmd/pt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 90d0d536c1..aa384b31b5 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -91,7 +91,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 309198894d37ad0f0417fc3475919fb02e3fc63c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:08:58 +0800 Subject: [PATCH 16/33] fix bug of intra-inter --- deepmd/pt_expt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index bd644e7206..b5042f6f2a 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -92,7 +92,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 85f05833353a8bb388ce0fd16cdf0059d579b5e4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:13:42 +0800 Subject: [PATCH 17/33] fix bug of default dp inter value --- deepmd/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/env.py b/deepmd/env.py index 7b29a338f1..c9d0fb241f 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -138,7 +138,7 @@ def get_default_nthreads() -> tuple[int, int]: ), int( os.environ.get( "DP_INTER_OP_PARALLELISM_THREADS", - os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"), + os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"), ) ) From d33324de397d728d15fa773bdb828783085e4156 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:15:07 +0800 Subject: [PATCH 18/33] fix cicd --- source/tests/consistent/descriptor/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 7c8cbce744..50efe32a08 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,13 +153,17 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + # Use the torch-native neighbor list utilities to avoid array_api_compat + # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc + # on CUDA, which all rely on aten::empty_strided and fail in CI builds + # where that CUDA kernel is not available. + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list( + nlist = build_neighbor_list_pt( ext_coords, ext_atype, natoms[0], From 4de9a565c01c6193a94eaecbc18475c4be04dc08 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:57:27 +0800 Subject: [PATCH 19/33] feat: add support for se_r --- deepmd/dpmodel/descriptor/se_r.py | 4 +- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/se_r.py | 83 +++++++++++ .../tests/consistent/descriptor/test_se_r.py | 25 ++++ source/tests/pt_expt/descriptor/test_se_r.py | 132 ++++++++++++++++++ 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/descriptor/se_r.py create mode 100644 source/tests/pt_expt/descriptor/test_se_r.py diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 6decd91a23..b38d561e95 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -391,7 +391,9 @@ def call( ng = self.neuron[-1] xyz_scatter = xp.zeros( - [nf, nloc, ng], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) rr = xp.astype(rr, xyz_scatter.dtype) diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 089e5619e0..4d9469a93a 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -5,8 +5,12 @@ from .se_e2_a import ( DescrptSeA, ) +from .se_r import ( + DescrptSeR, +) __all__ = [ "BaseDescriptor", "DescrptSeA", + "DescrptSeR", ] diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py new file mode 100644 index 0000000000..f4969ce927 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +torch = importlib.import_module("torch") + + +@BaseDescriptor.register("se_e2_r_expt") +@BaseDescriptor.register("se_r_expt") +class DescrptSeR(DescrptSeRDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeRDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 3420c5592f..9aafea1578 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_r import DescrptSeR as DescrptSeRPT else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_r import DescrptSeR as DescrptSeRPTExpt +else: + DescrptSeRPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_r import DescrptSeR as DescrptSeRTF else: @@ -84,6 +89,16 @@ def skip_pt(self) -> bool: ) = self.param return not type_one_side or CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -117,6 +132,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeRTF dp_class = DescrptSeRDP pt_class = DescrptSeRPT + pt_expt_class = DescrptSeRPTExpt jax_class = DescrptSeRJAX array_api_strict_class = DescrptSeRArrayAPIStrict args = descrpt_se_r_args() @@ -183,6 +199,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py new file mode 100644 index 0000000000..6e7339801c --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import itertools +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + +torch = importlib.import_module("torch") + + +class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, sw1], [rd2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From f4dc0afec4909dd4c052e9dcd565b4be01b6ed92 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:50:35 +0800 Subject: [PATCH 20/33] fix device of xp array --- deepmd/dpmodel/descriptor/se_r.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index b38d561e95..4fdf50beba 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -309,9 +309,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, From 238483531cd8455219ad5d21065f516f0486c97b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:51:50 +0800 Subject: [PATCH 21/33] fix device of xp array --- deepmd/dpmodel/descriptor/dpa1.py | 9 +++++++-- deepmd/dpmodel/descriptor/repflows.py | 9 +++++++-- deepmd/dpmodel/descriptor/repformers.py | 9 +++++++-- deepmd/dpmodel/descriptor/se_e2_a.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t_tebd.py | 9 +++++++-- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 5228ba55b2..f09ab24dfe 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -909,9 +909,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 706fc690e4..7ba4f92662 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -453,9 +453,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 79d4f9228f..06f5c1c943 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -417,9 +417,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 3ca28ba556..77afb110e9 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -350,9 +350,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 863187dd4c..749a5da188 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -290,9 +290,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index e118d5abd4..0a2d46c015 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -694,9 +694,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" From 9646d71b6d05d7c30fb70ec7e0be708e122bdbde Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:52:16 +0800 Subject: [PATCH 22/33] revert extend_coord_with_ghosts --- source/tests/consistent/descriptor/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 50efe32a08..7c8cbce744 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,17 +153,13 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - # Use the torch-native neighbor list utilities to avoid array_api_compat - # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc - # on CUDA, which all rely on aten::empty_strided and fail in CI builds - # where that CUDA kernel is not available. - ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list_pt( + nlist = build_neighbor_list( ext_coords, ext_atype, natoms[0], From f270069dd07c10bb9d908bd203de0e6e7ba72412 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 18:11:20 +0800 Subject: [PATCH 23/33] raise error for non-implemented methods --- deepmd/backend/pt_expt.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index e651332e2b..ade9eb51f3 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -76,9 +76,7 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT - - return DeepEvalPT + raise NotImplementedError @property def neighbor_stat(self) -> type["NeighborStat"]: @@ -89,11 +87,7 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - from deepmd.pt.utils.neighbor_stat import ( - NeighborStat, - ) - - return NeighborStat + raise NotImplementedError @property def serialize_hook(self) -> Callable[[str], dict]: @@ -104,11 +98,7 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - serialize_from_file, - ) - - return serialize_from_file + raise NotImplementedError @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -119,8 +109,4 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - deserialize_to_file, - ) - - return deserialize_to_file + raise NotImplementedError From 57433d3e1e82b00006a9062c4a3610bbf3b52c45 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:17:46 +0800 Subject: [PATCH 24/33] restore import torch --- deepmd/pt_expt/common.py | 4 +--- deepmd/pt_expt/descriptor/base_descriptor.py | 5 ++--- deepmd/pt_expt/descriptor/se_e2_a.py | 5 ++--- deepmd/pt_expt/descriptor/se_r.py | 5 ++--- deepmd/pt_expt/utils/env.py | 3 +-- deepmd/pt_expt/utils/exclude_mask.py | 5 ++--- deepmd/pt_expt/utils/network.py | 4 +--- pyproject.toml | 3 +++ source/tests/pt_expt/descriptor/test_se_e2_a.py | 4 +--- source/tests/pt_expt/descriptor/test_se_r.py | 4 +--- source/tests/pt_expt/utils/test_common.py | 4 +--- source/tests/pt_expt/utils/test_exclusion_mask.py | 4 +--- 12 files changed, 18 insertions(+), 32 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index b66c0ff66d..db8b94989b 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, overload, ) import numpy as np +import torch from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py index 51e9325bba..986435205a 100644 --- a/deepmd/pt_expt/descriptor/base_descriptor.py +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib + +import torch from deepmd.dpmodel.descriptor import ( make_base_descriptor, ) -torch = importlib.import_module("torch") - BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7df1148e38..21c0a4eeb7 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index f4969ce927..508785949c 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_r_expt") @BaseDescriptor.register("se_r_expt") diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index b5042f6f2a..ce13e4ef42 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import logging import multiprocessing import os @@ -18,7 +17,7 @@ ) log = logging.getLogger(__name__) -torch = importlib.import_module("torch") +import torch if sys.platform != "win32": try: diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index ed296c9f98..15fbbc8e34 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -1,17 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5f66959d16..3effcfc488 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ClassVar, ) import numpy as np +import torch from deepmd.dpmodel.common import ( NativeOP, @@ -22,8 +22,6 @@ to_torch_array, ) -torch = importlib.import_module("torch") - class TorchArrayParam(torch.nn.Parameter): def __new__( # noqa: PYI034 diff --git a/pyproject.toml b/pyproject.toml index bd403dfaf2..15eb0b2ae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -411,6 +411,7 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.pt_expt", "deepmd.pd", "deepmd.jax", "tensorflow", @@ -432,12 +433,14 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "data/**" = ["ANN"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253", "B905"] +"deepmd/pt_expt/**" = ["TID253", "B905"] "deepmd/jax/**" = ["TID253"] "deepmd/pd/**" = ["TID253", "B905"] "source/**" = ["ANN"] "source/tests/tf/**" = ["TID253", "ANN"] "source/tests/pt/**" = ["TID253", "ANN"] +"source/tests/pt_expt/**" = ["TID253", "ANN"] "source/tests/jax/**" = ["TID253", "ANN"] "source/tests/pd/**" = ["TID253", "ANN"] "source/tests/universal/pt/**" = ["TID253", "ANN"] diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index 57923b97a3..e63138e43b 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA from deepmd.pt_expt.descriptor.se_e2_a import ( @@ -29,8 +29,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index 6e7339801c..c789b13652 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR from deepmd.pt_expt.descriptor.se_r import ( @@ -26,8 +26,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py index 63c4983f23..ee8a7ca324 100644 --- a/source/tests/pt_expt/utils/test_common.py +++ b/source/tests/pt_expt/utils/test_common.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import numpy as np +import torch from deepmd.pt_expt.common import ( to_torch_array, @@ -10,8 +10,6 @@ env, ) -torch = importlib.import_module("torch") - def test_to_torch_array_moves_device() -> None: arr = np.arange(6, dtype=np.float32).reshape(2, 3) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index 7168579052..b3707ef69d 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import unittest import numpy as np +import torch from deepmd.pt_expt.utils import ( env, @@ -16,8 +16,6 @@ TestCaseSingleFrameWithNlist, ) -torch = importlib.import_module("torch") - class TestAtomExcludeMask(unittest.TestCase): def test_build_type_exclude_mask(self) -> None: From eedcbaf4f67ff6a9dface303450c4110b6c139b2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:31:03 +0800 Subject: [PATCH 25/33] fix(pt,pt-expt): guard thread setters --- deepmd/pt/utils/env.py | 15 ++++++++++++-- deepmd/pt_expt/utils/env.py | 15 ++++++++++++-- source/tests/pt/test_env_threads.py | 28 ++++++++++++++++++++++++++ source/tests/pt_expt/utils/test_env.py | 28 ++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 source/tests/pt/test_env_threads.py create mode 100644 source/tests/pt_expt/utils/test_env.py diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index aa384b31b5..9f453c895c 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when multiple backends are imported. + try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. + try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index ce13e4ef42..56cec25d49 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when both pt and pt_expt env modules are imported. + try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. + try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py new file mode 100644 index 0000000000..eb6604ceb8 --- /dev/null +++ b/source/tests/pt/test_env_threads.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + import deepmd.pt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py new file mode 100644 index 0000000000..bbdc696aea --- /dev/null +++ b/source/tests/pt_expt/utils/test_env.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + import deepmd.pt_expt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) From d8b2cf43faa618a15e1e590688bd002bf75abe28 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:54:34 +0800 Subject: [PATCH 26/33] make exclusion mask modules --- deepmd/pt_expt/utils/exclude_mask.py | 26 ++++++++++++++++--- .../pt_expt/utils/test_exclusion_mask.py | 8 ++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index 15fbbc8e34..e757283e1c 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -12,15 +12,33 @@ ) -class AtomExcludeMask(AtomExcludeMaskDP): +class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + AtomExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) -class PairExcludeMask(PairExcludeMaskDP): +class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + PairExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index b3707ef69d..6f836913af 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -39,6 +39,10 @@ def test_build_type_exclude_mask(self) -> None: mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + def test_type_mask_is_buffer(self) -> None: + des = AtomExcludeMask(3, exclude_types=[0]) + assert "type_mask" in des.state_dict() + class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: @@ -62,3 +66,7 @@ def test_build_type_exclude_mask(self) -> None: torch.as_tensor(self.atype_ext, device=env.DEVICE), ) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + def test_type_mask_is_buffer(self) -> None: + des = PairExcludeMask(self.nt, exclude_types=[[0, 1]]) + assert "type_mask" in des.state_dict() From aeef15a99d6e9ccf8f9cbb93a149451adf7b615e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 23:46:24 +0800 Subject: [PATCH 27/33] fix(pt-expt): clear params on None --- deepmd/pt_expt/utils/network.py | 6 ++++++ source/tests/pt_expt/utils/test_network.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_network.py diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 3effcfc488..721a511f5f 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -48,6 +48,12 @@ def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) if val is None: + if name in self._parameters: + self._parameters[name] = None + return + if name in self._buffers: + self._buffers[name] = None + return return super().__setattr__(name, None) if getattr(self, "trainable", False): param = ( diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py new file mode 100644 index 0000000000..ad7c2a7e3d --- /dev/null +++ b/source/tests/pt_expt/utils/test_network.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.pt_expt.utils.network import ( + NativeLayer, +) + + +def test_native_layer_clears_parameter_on_none() -> None: + layer = NativeLayer(2, 3, trainable=True) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._parameters.get("w") is None + + +def test_native_layer_clears_buffer_on_none() -> None: + layer = NativeLayer(2, 3, trainable=False) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._buffers.get("w") is None From 8bdb1f89eb509efc8d6133812ec5e9a2c678ae7b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 7 Feb 2026 18:51:50 +0800 Subject: [PATCH 28/33] fix bug --- source/tests/pt/test_env_threads.py | 12 +++++++++--- source/tests/pt_expt/utils/test_env.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py index eb6604ceb8..50de1996d8 100644 --- a/source/tests/pt/test_env_threads.py +++ b/source/tests/pt/test_env_threads.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, **kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py index bbdc696aea..a589c80ae1 100644 --- a/source/tests/pt_expt/utils/test_env.py +++ b/source/tests/pt_expt/utils/test_env.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, **kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt_expt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) From d3b01da5075898e807aa08a34e2fcb49d3b47749 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 16:36:56 +0800 Subject: [PATCH 29/33] utility to handel dpmodel -> pt_expt conversion --- deepmd/pt_expt/common.py | 253 ++++++++++++++++++++++++++- deepmd/pt_expt/descriptor/se_e2_a.py | 39 +---- deepmd/pt_expt/descriptor/se_r.py | 39 +---- deepmd/pt_expt/utils/__init__.py | 4 + deepmd/pt_expt/utils/exclude_mask.py | 39 +++-- deepmd/pt_expt/utils/network.py | 7 + 6 files changed, 293 insertions(+), 88 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index db8b94989b..e687fa8e48 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,4 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Common utilities for the pt_expt backend. + +This module provides the core infrastructure for automatically wrapping dpmodel +classes (array_api_compat-based) as PyTorch modules. The key insight is to +detect attributes by their **value type** rather than by hard-coded names: + +- numpy arrays → torch buffers (persistent state like statistics, masks) +- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup) +- None values → clear existing buffers + +This eliminates the need to manually enumerate attribute names in each wrapper's +__setattr__ method, making the codebase more maintainable when dpmodel adds +new attributes. +""" + +from collections.abc import ( + Callable, +) from typing import ( Any, overload, @@ -7,11 +25,203 @@ import numpy as np import torch -from deepmd.pt_expt.utils import ( - env, -) +# --------------------------------------------------------------------------- +# dpmodel → pt_expt converter registry +# --------------------------------------------------------------------------- +_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {} +"""Registry mapping dpmodel classes to their pt_expt converter functions. + +This registry is populated at module import time via `register_dpmodel_mapping` +calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When +dpmodel_setattr encounters a dpmodel object, it looks up the object's type in +this registry to find the appropriate converter. + +Examples of registered mappings: +- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...) +- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize()) +""" + + +def register_dpmodel_mapping( + dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module] +) -> None: + """Register a converter that turns a dpmodel instance into a pt_expt Module. + + This function is called at module import time by each pt_expt wrapper to + register how dpmodel objects should be converted when they're assigned as + attributes. The converter is a callable that takes a dpmodel instance and + returns the corresponding pt_expt torch.nn.Module wrapper. + + Parameters + ---------- + dpmodel_cls : type + The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP). + This is the key used for lookup in dpmodel_setattr. + converter : Callable[[Any], torch.nn.Module] + A callable that converts a dpmodel instance to a pt_expt module. + Common patterns: + - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...) + - Round-trip via serialization: lambda v: PtExptClass.deserialize(v.serialize()) + + Notes + ----- + This function must be called AFTER the pt_expt wrapper class is defined but + BEFORE dpmodel_setattr might encounter instances of dpmodel_cls. In practice, + this means calling it immediately after the wrapper class definition at module + import time. + + Examples + -------- + >>> register_dpmodel_mapping( + ... AtomExcludeMaskDP, + ... lambda v: AtomExcludeMask( + ... v.ntypes, exclude_types=list(v.get_exclude_types()) + ... ), + ... ) + """ + _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter + + +def try_convert_module(value: Any) -> torch.nn.Module | None: + """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. + + This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + registry. If a converter is found, it invokes it to produce a torch.nn.Module + wrapper; otherwise it returns None. + + Parameters + ---------- + value : Any + The value to potentially convert. Typically a dpmodel object like + AtomExcludeMaskDP or NetworkCollectionDP. + + Returns + ------- + torch.nn.Module or None + The converted pt_expt module if a converter is registered for value's + type, otherwise None. + + Notes + ----- + This function uses exact type matching (not isinstance checks) to ensure + predictable behavior. Each dpmodel class must be explicitly registered via + register_dpmodel_mapping. + + The function is called by dpmodel_setattr when it encounters an object that + might be a dpmodel instance. If conversion succeeds, the caller should use + the converted module instead of the original value. + """ + converter = _DPMODEL_TO_PT_EXPT.get(type(value)) + if converter is not None: + return converter(value) + return None + + +def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]: + """Common __setattr__ logic for pt_expt wrappers around dpmodel classes. + This function implements automatic attribute detection by value type, eliminating + the need to hard-code attribute names in each wrapper's __setattr__ method. It + handles three cases: + 1. **numpy arrays → torch buffers**: Persistent state like statistics (davg, dstd) + or masks that should be saved in state_dict and moved with .to(device). + 2. **None values → clear buffers**: Setting an existing buffer to None. + 3. **dpmodel objects → pt_expt modules**: Nested dpmodel objects like + AtomExcludeMaskDP or NetworkCollectionDP are converted to their pt_expt + wrappers via the registry. + + Parameters + ---------- + obj : torch.nn.Module + The pt_expt wrapper object whose attribute is being set. Must be a + torch.nn.Module (caller should verify this). + name : str + The attribute name being set. + value : Any + The value being assigned. This function inspects the type to determine + how to handle it. + + Returns + ------- + handled : bool + True if the attribute has been fully set (caller should NOT call + super().__setattr__). False if the caller should forward the (possibly + converted) value to super().__setattr__(name, value). + value : Any + The value to use. May be converted (e.g., dpmodel object → pt_expt module) + or unchanged (e.g., scalar, list, or unregistered object). + + Notes + ----- + **Why this design is safe:** + + - In dpmodel, all persistent arrays use `self.xxx = np.array(...)`. Scalars + use `.item()`, lists use `.tolist()`. So `isinstance(value, np.ndarray)` + reliably identifies buffer-worthy attributes. + - torch.Tensor values assigned to existing buffers fall through to + torch.nn.Module.__setattr__, which correctly updates them. + - dpmodel objects are identified by registry lookup (exact type match), so + only explicitly registered types are converted. + - The function checks `"_buffers" in obj.__dict__` to ensure the object has + been initialized as a torch.nn.Module before attempting buffer operations. + + **Circular import resolution:** + + The function uses a deferred import `from deepmd.pt_expt.utils import env` + inside the function body. This breaks the circular dependency chain: + common.py → utils/__init__.py → exclude_mask.py → common.py. The import is + cached by Python after the first call, so there's no performance penalty. + + **Usage pattern:** + + Typical wrapper classes use this three-line pattern: + + >>> class MyWrapper(MyDPModel, torch.nn.Module): + ... def __setattr__(self, name, value): + ... handled, value = dpmodel_setattr(self, name, value) + ... if not handled: + ... super().__setattr__(name, value) + + Examples + -------- + >>> # Case 1: numpy array → buffer + >>> obj.davg = np.array([1.0, 2.0]) # becomes torch.Tensor buffer + >>> + >>> # Case 2: clear buffer + >>> obj.davg = None # sets buffer to None + >>> + >>> # Case 3: dpmodel object → pt_expt module + >>> obj.emask = AtomExcludeMaskDP(...) # becomes AtomExcludeMask module + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + + # numpy array → torch buffer + if isinstance(value, np.ndarray) and "_buffers" in obj.__dict__: + tensor = torch.as_tensor(value, device=env.DEVICE) + if name in obj._buffers: + obj._buffers[name] = tensor + return True, tensor + obj.register_buffer(name, tensor) + return True, tensor + + # clear an existing buffer to None + if value is None and "_buffers" in obj.__dict__ and name in obj._buffers: + obj._buffers[name] = None + return True, None + + # dpmodel object → pt_expt module + if "_modules" in obj.__dict__: + converted = try_convert_module(value) + if converted is not None: + return False, converted + + return False, value + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... @@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... def to_torch_array(array: Any) -> torch.Tensor | None: - """Convert input to a torch tensor on the pt-expt device.""" + """Convert input to a torch tensor on the pt_expt device. + + This utility function handles conversion from various array-like types (numpy + arrays, torch tensors on different devices, etc.) to torch tensors on the + pt_expt backend's configured device. + + Parameters + ---------- + array : Any + The input to convert. Can be: + - None (returns None) + - torch.Tensor (moves to pt_expt device) + - numpy array or array-like (converts to torch.Tensor on pt_expt device) + + Returns + ------- + torch.Tensor or None + The input as a torch tensor on the pt_expt device (env.DEVICE), or None + if the input was None. + + Notes + ----- + This function uses the same deferred import pattern as dpmodel_setattr to + avoid circular dependencies. The env module determines the target device + (typically CPU for pt_expt). + + Examples + -------- + >>> import numpy as np + >>> arr = np.array([1.0, 2.0, 3.0]) + >>> tensor = to_torch_array(arr) + >>> tensor.device + device(type='cpu') # or whatever env.DEVICE is set to + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + if array is None: return None if torch.is_tensor(array): diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 21c0a4eeb7..1ccb4d2dda 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_a_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 508785949c..7a406fb499 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_r_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index f90cf82249..93f765a27c 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -4,8 +4,12 @@ AtomExcludeMask, PairExcludeMask, ) +from .network import ( + NetworkCollection, +) __all__ = [ "AtomExcludeMask", + "NetworkCollection", "PairExcludeMask", ] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index e757283e1c..4060b8c446 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -7,8 +7,9 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, ) @@ -18,14 +19,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: AtomExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + AtomExcludeMaskDP, + lambda v: AtomExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): @@ -34,11 +36,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: PairExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + PairExcludeMaskDP, + lambda v: PairExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 721a511f5f..84d0024a85 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,6 +19,7 @@ make_multilayer_network, ) from deepmd.pt_expt.common import ( + register_dpmodel_mapping, to_torch_array, ) @@ -121,5 +122,11 @@ def __setitem__(self, key: int | tuple, value: Any) -> None: del self._module_networks[key_str] +register_dpmodel_mapping( + NetworkCollectionDP, + lambda v: NetworkCollection.deserialize(v.serialize()), +) + + class LayerNorm(LayerNormDP, NativeLayer): pass From 3452a2a8c0ef161d4bd827066c77e2b1bda7be77 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 17:06:19 +0800 Subject: [PATCH 30/33] fix to_numpy_array device --- deepmd/dpmodel/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index bd6f7dac49..dabbc34e01 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -121,10 +121,11 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None: try: # asarray is not within Array API standard, so may fail return np.asarray(x) - except (ValueError, AttributeError): + except (ValueError, AttributeError, TypeError): xp = array_api_compat.array_namespace(x) # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. - x = xp.asarray(x, copy=True) + # Move to CPU device to ensure numpy compatibility + x = xp.asarray(x, device="cpu", copy=True) return np.from_dlpack(x) From 87e9b9d7b45b02e63cf4142258b85303e8a20b88 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Feb 2026 19:53:38 +0800 Subject: [PATCH 31/33] better type checking --- deepmd/pt_expt/common.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index e687fa8e48..114c08e24b 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -25,10 +25,14 @@ import numpy as np import torch +from deepmd.dpmodel.common import ( + NativeOP, +) + # --------------------------------------------------------------------------- # dpmodel → pt_expt converter registry # --------------------------------------------------------------------------- -_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {} +_DPMODEL_TO_PT_EXPT: dict[type[NativeOP], Callable[[NativeOP], torch.nn.Module]] = {} """Registry mapping dpmodel classes to their pt_expt converter functions. This registry is populated at module import time via `register_dpmodel_mapping` @@ -43,7 +47,7 @@ def register_dpmodel_mapping( - dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module] + dpmodel_cls: type[NativeOP], converter: Callable[[NativeOP], torch.nn.Module] ) -> None: """Register a converter that turns a dpmodel instance into a pt_expt Module. @@ -54,10 +58,10 @@ def register_dpmodel_mapping( Parameters ---------- - dpmodel_cls : type + dpmodel_cls : type[NativeOP] The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP). This is the key used for lookup in dpmodel_setattr. - converter : Callable[[Any], torch.nn.Module] + converter : Callable[[NativeOP], torch.nn.Module] A callable that converts a dpmodel instance to a pt_expt module. Common patterns: - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...) @@ -212,9 +216,17 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, # dpmodel object → pt_expt module if "_modules" in obj.__dict__: - converted = try_convert_module(value) - if converted is not None: - return False, converted + # Check if this is a NativeOP that needs conversion + if isinstance(value, NativeOP) and not isinstance(value, torch.nn.Module): + converted = try_convert_module(value) + if converted is not None: + return False, converted + # If it's a NativeOP but not registered, this is likely a bug + raise TypeError( + f"Attempted to assign a dpmodel object of type {type(value).__name__} " + f"but no converter is registered. Please call register_dpmodel_mapping " + f"for this type." + ) return False, value From ef84c6c7d5c472d0f27b31d1c1fe8048e136924d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Feb 2026 19:58:24 +0800 Subject: [PATCH 32/33] fix --- deepmd/pt_expt/common.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index 114c08e24b..01fd36856c 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -216,17 +216,12 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, # dpmodel object → pt_expt module if "_modules" in obj.__dict__: - # Check if this is a NativeOP that needs conversion - if isinstance(value, NativeOP) and not isinstance(value, torch.nn.Module): - converted = try_convert_module(value) - if converted is not None: - return False, converted - # If it's a NativeOP but not registered, this is likely a bug - raise TypeError( - f"Attempted to assign a dpmodel object of type {type(value).__name__} " - f"but no converter is registered. Please call register_dpmodel_mapping " - f"for this type." - ) + converted = try_convert_module(value) + if converted is not None: + return False, converted + # Note: Some NativeOP objects (like EnvMat) don't need conversion and can + # be used directly. If a NativeOP truly needs conversion but isn't registered, + # it will fail at runtime when the object is actually used. return False, value From 55e094e2b80b9f05eea1733423cdc109d1de329f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Feb 2026 20:17:00 +0800 Subject: [PATCH 33/33] raise error --- deepmd/pt_expt/common.py | 34 ++++++++++++++++++++++++++------ deepmd/pt_expt/utils/__init__.py | 11 +++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index 01fd36856c..c7c6cff99b 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -216,12 +216,19 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, # dpmodel object → pt_expt module if "_modules" in obj.__dict__: - converted = try_convert_module(value) - if converted is not None: - return False, converted - # Note: Some NativeOP objects (like EnvMat) don't need conversion and can - # be used directly. If a NativeOP truly needs conversion but isn't registered, - # it will fail at runtime when the object is actually used. + # Try to convert dpmodel objects that aren't already torch.nn.Modules + if not isinstance(value, torch.nn.Module): + converted = try_convert_module(value) + if converted is not None: + return False, converted + # If this is a NativeOP that should have been registered but wasn't, raise error + if isinstance(value, NativeOP): + raise TypeError( + f"Attempted to assign a dpmodel object of type {type(value).__name__} " + f"but no converter is registered. Please call register_dpmodel_mapping " + f"for this type. If this object doesn't need conversion, register it " + f"with an identity converter: lambda v: v" + ) return False, value @@ -283,3 +290,18 @@ def to_torch_array(array: Any) -> torch.Tensor | None: if torch.is_tensor(array): return array.to(device=env.DEVICE) return torch.as_tensor(array, device=env.DEVICE) + + +# Import utils to trigger dpmodel→pt_expt converter registrations +# This must happen after the functions above are defined to avoid circular imports +def _ensure_registrations() -> None: + """Import pt_expt.utils modules to register converters. + + This function is called on module import to ensure all dpmodel→pt_expt + converters are registered before any descriptors/fittings try to use them. + """ + # Import triggers registration of NetworkCollection, ExcludeMask, EnvMat + from deepmd.pt_expt import utils # noqa: F401 + + +_ensure_registrations() diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 93f765a27c..bcd3d4450a 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -1,5 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.utils.env_mat import ( + EnvMat, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, +) + from .exclude_mask import ( AtomExcludeMask, PairExcludeMask, @@ -8,6 +15,10 @@ NetworkCollection, ) +# Register EnvMat with identity converter - it doesn't need wrapping +# as it's a stateless utility class +register_dpmodel_mapping(EnvMat, lambda v: v) + __all__ = [ "AtomExcludeMask", "NetworkCollection",