diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index feaded0b01..ca8663d13b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -785,7 +785,112 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": return EN -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +class EmbeddingNet(NativeNet): + """The embedding network. + + Parameters + ---------- + in_dim + Input dimension. + neuron + The number of neurons in each layer. The output dimension + is the same as the dimension of the last layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. + seed : int, optional + Random seed. + bias : bool, Optional + Whether to use bias in the embedding layer. + trainable : bool or list[bool], Optional + Whether the weights are trainable. If a list, each element + corresponds to a layer. + """ + + def __init__( + self, + in_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + seed: int | list[int] | None = None, + bias: bool = True, + trainable: bool | list[bool] = True, + ) -> None: + layers = [] + i_in = in_dim + if isinstance(trainable, bool): + trainable = [trainable] * len(neuron) + for idx, ii in enumerate(neuron): + i_ot = ii + layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias, + use_timestep=resnet_dt, + activation_function=activation_function, + resnet=True, + precision=precision, + seed=child_seed(seed, idx), + trainable=trainable[idx], + ).serialize() + ) + i_in = i_ot + super().__init__(layers) + self.in_dim = in_dim + self.neuron = neuron + self.activation_function = activation_function + self.resnet_dt = resnet_dt + self.precision = precision + self.bias = bias + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "EmbeddingNetwork", + "@version": 2, + "in_dim": self.in_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "bias": self.bias, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "EmbeddingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Reinitialize layers from serialized data, using the same layer type + # that __init__ created (respects subclass overrides via MRO). + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + return obj def make_fitting_network( diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index db8b94989b..e687fa8e48 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,4 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Common utilities for the pt_expt backend. + +This module provides the core infrastructure for automatically wrapping dpmodel +classes (array_api_compat-based) as PyTorch modules. The key insight is to +detect attributes by their **value type** rather than by hard-coded names: + +- numpy arrays → torch buffers (persistent state like statistics, masks) +- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup) +- None values → clear existing buffers + +This eliminates the need to manually enumerate attribute names in each wrapper's +__setattr__ method, making the codebase more maintainable when dpmodel adds +new attributes. +""" + +from collections.abc import ( + Callable, +) from typing import ( Any, overload, @@ -7,11 +25,203 @@ import numpy as np import torch -from deepmd.pt_expt.utils import ( - env, -) +# --------------------------------------------------------------------------- +# dpmodel → pt_expt converter registry +# --------------------------------------------------------------------------- +_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {} +"""Registry mapping dpmodel classes to their pt_expt converter functions. + +This registry is populated at module import time via `register_dpmodel_mapping` +calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When +dpmodel_setattr encounters a dpmodel object, it looks up the object's type in +this registry to find the appropriate converter. + +Examples of registered mappings: +- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...) +- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize()) +""" + + +def register_dpmodel_mapping( + dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module] +) -> None: + """Register a converter that turns a dpmodel instance into a pt_expt Module. + + This function is called at module import time by each pt_expt wrapper to + register how dpmodel objects should be converted when they're assigned as + attributes. The converter is a callable that takes a dpmodel instance and + returns the corresponding pt_expt torch.nn.Module wrapper. + + Parameters + ---------- + dpmodel_cls : type + The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP). + This is the key used for lookup in dpmodel_setattr. + converter : Callable[[Any], torch.nn.Module] + A callable that converts a dpmodel instance to a pt_expt module. + Common patterns: + - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...) + - Round-trip via serialization: lambda v: PtExptClass.deserialize(v.serialize()) + + Notes + ----- + This function must be called AFTER the pt_expt wrapper class is defined but + BEFORE dpmodel_setattr might encounter instances of dpmodel_cls. In practice, + this means calling it immediately after the wrapper class definition at module + import time. + + Examples + -------- + >>> register_dpmodel_mapping( + ... AtomExcludeMaskDP, + ... lambda v: AtomExcludeMask( + ... v.ntypes, exclude_types=list(v.get_exclude_types()) + ... ), + ... ) + """ + _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter + + +def try_convert_module(value: Any) -> torch.nn.Module | None: + """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. + + This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + registry. If a converter is found, it invokes it to produce a torch.nn.Module + wrapper; otherwise it returns None. + + Parameters + ---------- + value : Any + The value to potentially convert. Typically a dpmodel object like + AtomExcludeMaskDP or NetworkCollectionDP. + + Returns + ------- + torch.nn.Module or None + The converted pt_expt module if a converter is registered for value's + type, otherwise None. + + Notes + ----- + This function uses exact type matching (not isinstance checks) to ensure + predictable behavior. Each dpmodel class must be explicitly registered via + register_dpmodel_mapping. + + The function is called by dpmodel_setattr when it encounters an object that + might be a dpmodel instance. If conversion succeeds, the caller should use + the converted module instead of the original value. + """ + converter = _DPMODEL_TO_PT_EXPT.get(type(value)) + if converter is not None: + return converter(value) + return None + + +def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]: + """Common __setattr__ logic for pt_expt wrappers around dpmodel classes. + This function implements automatic attribute detection by value type, eliminating + the need to hard-code attribute names in each wrapper's __setattr__ method. It + handles three cases: + 1. **numpy arrays → torch buffers**: Persistent state like statistics (davg, dstd) + or masks that should be saved in state_dict and moved with .to(device). + 2. **None values → clear buffers**: Setting an existing buffer to None. + 3. **dpmodel objects → pt_expt modules**: Nested dpmodel objects like + AtomExcludeMaskDP or NetworkCollectionDP are converted to their pt_expt + wrappers via the registry. + + Parameters + ---------- + obj : torch.nn.Module + The pt_expt wrapper object whose attribute is being set. Must be a + torch.nn.Module (caller should verify this). + name : str + The attribute name being set. + value : Any + The value being assigned. This function inspects the type to determine + how to handle it. + + Returns + ------- + handled : bool + True if the attribute has been fully set (caller should NOT call + super().__setattr__). False if the caller should forward the (possibly + converted) value to super().__setattr__(name, value). + value : Any + The value to use. May be converted (e.g., dpmodel object → pt_expt module) + or unchanged (e.g., scalar, list, or unregistered object). + + Notes + ----- + **Why this design is safe:** + + - In dpmodel, all persistent arrays use `self.xxx = np.array(...)`. Scalars + use `.item()`, lists use `.tolist()`. So `isinstance(value, np.ndarray)` + reliably identifies buffer-worthy attributes. + - torch.Tensor values assigned to existing buffers fall through to + torch.nn.Module.__setattr__, which correctly updates them. + - dpmodel objects are identified by registry lookup (exact type match), so + only explicitly registered types are converted. + - The function checks `"_buffers" in obj.__dict__` to ensure the object has + been initialized as a torch.nn.Module before attempting buffer operations. + + **Circular import resolution:** + + The function uses a deferred import `from deepmd.pt_expt.utils import env` + inside the function body. This breaks the circular dependency chain: + common.py → utils/__init__.py → exclude_mask.py → common.py. The import is + cached by Python after the first call, so there's no performance penalty. + + **Usage pattern:** + + Typical wrapper classes use this three-line pattern: + + >>> class MyWrapper(MyDPModel, torch.nn.Module): + ... def __setattr__(self, name, value): + ... handled, value = dpmodel_setattr(self, name, value) + ... if not handled: + ... super().__setattr__(name, value) + + Examples + -------- + >>> # Case 1: numpy array → buffer + >>> obj.davg = np.array([1.0, 2.0]) # becomes torch.Tensor buffer + >>> + >>> # Case 2: clear buffer + >>> obj.davg = None # sets buffer to None + >>> + >>> # Case 3: dpmodel object → pt_expt module + >>> obj.emask = AtomExcludeMaskDP(...) # becomes AtomExcludeMask module + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + + # numpy array → torch buffer + if isinstance(value, np.ndarray) and "_buffers" in obj.__dict__: + tensor = torch.as_tensor(value, device=env.DEVICE) + if name in obj._buffers: + obj._buffers[name] = tensor + return True, tensor + obj.register_buffer(name, tensor) + return True, tensor + + # clear an existing buffer to None + if value is None and "_buffers" in obj.__dict__ and name in obj._buffers: + obj._buffers[name] = None + return True, None + + # dpmodel object → pt_expt module + if "_modules" in obj.__dict__: + converted = try_convert_module(value) + if converted is not None: + return False, converted + + return False, value + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... @@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... def to_torch_array(array: Any) -> torch.Tensor | None: - """Convert input to a torch tensor on the pt-expt device.""" + """Convert input to a torch tensor on the pt_expt device. + + This utility function handles conversion from various array-like types (numpy + arrays, torch tensors on different devices, etc.) to torch tensors on the + pt_expt backend's configured device. + + Parameters + ---------- + array : Any + The input to convert. Can be: + - None (returns None) + - torch.Tensor (moves to pt_expt device) + - numpy array or array-like (converts to torch.Tensor on pt_expt device) + + Returns + ------- + torch.Tensor or None + The input as a torch tensor on the pt_expt device (env.DEVICE), or None + if the input was None. + + Notes + ----- + This function uses the same deferred import pattern as dpmodel_setattr to + avoid circular dependencies. The env module determines the target device + (typically CPU for pt_expt). + + Examples + -------- + >>> import numpy as np + >>> arr = np.array([1.0, 2.0, 3.0]) + >>> tensor = to_torch_array(arr) + >>> tensor.device + device(type='cpu') # or whatever env.DEVICE is set to + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + if array is None: return None if torch.is_tensor(array): diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 21c0a4eeb7..1ccb4d2dda 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_a_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 508785949c..7a406fb499 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_r_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index f90cf82249..93f765a27c 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -4,8 +4,12 @@ AtomExcludeMask, PairExcludeMask, ) +from .network import ( + NetworkCollection, +) __all__ = [ "AtomExcludeMask", + "NetworkCollection", "PairExcludeMask", ] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index e757283e1c..4060b8c446 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -7,8 +7,9 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, ) @@ -18,14 +19,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: AtomExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + AtomExcludeMaskDP, + lambda v: AtomExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): @@ -34,11 +36,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: PairExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + PairExcludeMaskDP, + lambda v: PairExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 721a511f5f..b115214056 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -10,15 +10,16 @@ from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_embedding_network, make_fitting_network, make_multilayer_network, ) from deepmd.pt_expt.common import ( + register_dpmodel_mapping, to_torch_array, ) @@ -90,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) -class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): - pass +class EmbeddingNet(EmbeddingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EmbeddingNetDP.__init__(self, *args, **kwargs) + # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances. + # Convert to pt_expt NativeLayer and wrap in ModuleList. + self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + EmbeddingNetDP, + lambda v: EmbeddingNet.deserialize(v.serialize()), +) class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): @@ -121,5 +141,11 @@ def __setitem__(self, key: int | tuple, value: Any) -> None: del self._module_networks[key_str] +register_dpmodel_mapping( + NetworkCollectionDP, + lambda v: NetworkCollection.deserialize(v.serialize()), +) + + class LayerNorm(LayerNormDP, NativeLayer): pass diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 1ea5b1fdf9..3a95dd7af0 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -180,6 +180,76 @@ def test_embedding_net(self) -> None: inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify EmbeddingNet is a concrete class, not factory-generated.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Check it's the actual EmbeddingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "EmbeddingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(len(net.layers), len(neuron)) + + def test_forward_pass(self) -> None: + """Test EmbeddingNet forward pass produces correct shapes.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + rng = np.random.default_rng() + x = rng.standard_normal((5, in_dim)) + out = net.call(x) + self.assertEqual(out.shape, (5, neuron[-1])) + self.assertEqual(out.dtype, np.float64) + + def test_trainable_parameter_variants(self) -> None: + """Test EmbeddingNet with different trainable configurations.""" + in_dim = 4 + neuron = [8, 16] + + # All trainable + net_trainable = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_trainable.layers: + self.assertTrue(layer.trainable) + + # All frozen + net_frozen = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_frozen.layers: + self.assertFalse(layer.trainable) + + # Mixed trainable + net_mixed = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=[True, False], + ) + self.assertTrue(net_mixed.layers[0].trainable) + self.assertFalse(net_mixed.layers[1].trainable) + class TestFittingNet(unittest.TestCase): def test_fitting_net(self) -> None: diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index ad7c2a7e3d..083ea2fbd9 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -1,9 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +import numpy as np +import torch + +from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet +from deepmd.pt_expt.utils import ( + env, +) from deepmd.pt_expt.utils.network import ( + EmbeddingNet, NativeLayer, ) +from ...seed import ( + GLOBAL_SEED, +) + def test_native_layer_clears_parameter_on_none() -> None: layer = NativeLayer(2, 3, trainable=True) @@ -19,3 +32,252 @@ def test_native_layer_clears_buffer_on_none() -> None: layer.w = None assert layer.w is None assert layer._buffers.get("w") is None + + +class TestEmbeddingNetRefactor(unittest.TestCase): + """Tests for the refactored EmbeddingNet pt_expt wrapper and integration.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.neuron = [8, 16, 32] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def test_pt_expt_embedding_net_wraps_dpmodel(self) -> None: + """Verify pt_expt EmbeddingNet correctly wraps dpmodel.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check it's also a DPEmbeddingNet + self.assertIsInstance(net, DPEmbeddingNet) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_embedding_net_forward(self) -> None: + """Test pt_expt EmbeddingNet forward pass returns torch.Tensor.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.neuron[-1])) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt EmbeddingNet serialization/deserialization.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, NativeLayer) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + out2.detach().cpu().numpy(), + ) + + def test_deserialize_preserves_layer_type(self) -> None: + """Test that deserialize uses type(obj.layers[0]) to preserve subclass layers. + + This is the key fix: dpmodel's deserialize no longer hardcodes + super(EmbeddingNet, obj).__init__(layers), which would overwrite + pt_expt's converted layers. Instead it uses type(obj.layers[0]) + to respect the subclass's layer type. + """ + # Create pt_expt EmbeddingNet + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Verify layers are pt_expt NativeLayer (torch modules) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + + # Deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify deserialized layers are STILL pt_expt NativeLayer, not dpmodel + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + # This would fail if deserialize used hardcoded dpmodel layers + self.assertIsInstance(layer, NativeLayer) + + def test_cross_backend_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt EmbeddingNet.""" + # Create both with same seed + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + pt_net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Test forward pass + rng = np.random.default_rng() + x_np = rng.standard_normal((5, self.in_dim)) + x_torch = torch.from_numpy(x_np) + + out_dp = dp_net.call(x_np) + out_pt = pt_net(x_torch).detach().cpu().numpy() + + np.testing.assert_allclose(out_dp, out_pt, rtol=1e-10, atol=1e-10) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that the registry auto-converts dpmodel EmbeddingNet to pt_expt.""" + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt EmbeddingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, EmbeddingNet) + + # Verify layers are pt_expt NativeLayer + for layer in converted.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_auto_conversion_in_setattr(self) -> None: + """Test that dpmodel_setattr auto-converts EmbeddingNet attributes.""" + from deepmd.pt_expt.common import ( + dpmodel_setattr, + ) + + # Create a simple torch module + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dummy = None + + obj = TestModule() + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Use dpmodel_setattr to set it + handled, value = dpmodel_setattr(obj, "embedding_net", dp_net) + + # Should not be handled (returns converted value for caller to set) + self.assertFalse(handled) + # Value should be converted to pt_expt EmbeddingNet + self.assertIsInstance(value, torch.nn.Module) + self.assertIsInstance(value, EmbeddingNet) + + def test_trainable_parameter_handling(self) -> None: + """Test that trainable parameters work correctly in pt_expt.""" + # Test with trainable=True + net_trainable = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=True, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters + param_count = sum( + p.numel() for p in net_trainable.parameters() if p.requires_grad + ) + self.assertGreater(param_count, 0) + + # Check all layer parameters are trainable + for layer in net_trainable.layers: + if layer.w is not None: + self.assertTrue(layer.w.requires_grad) + if layer.b is not None: + self.assertTrue(layer.b.requires_grad) + + # Test with trainable=False + net_frozen = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=False, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters (should be 0) + param_count_frozen = sum( + p.numel() for p in net_frozen.parameters() if p.requires_grad + ) + self.assertEqual(param_count_frozen, 0) + + # Check all layer weights are buffers, not parameters + for layer in net_frozen.layers: + if layer.w is not None: + self.assertFalse(layer.w.requires_grad)