diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 2353e207a3..9866ddbc3a 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import math +from collections.abc import ( + Callable, +) from typing import ( Any, ) @@ -30,6 +33,9 @@ map_atom_exclude_types, map_pair_exclude_types, ) +from deepmd.utils.path import ( + DPPath, +) from .make_base_atomic_model import ( make_base_atomic_model, @@ -246,6 +252,196 @@ def call( aparam=aparam, ) + def get_intensive(self) -> bool: + """Whether the fitting property is intensive.""" + return False + + def get_compute_stats_distinguish_types(self) -> bool: + """Get whether the fitting net computes stats which are not distinguished between different types of atoms.""" + return True + + def compute_or_load_out_stat( + self, + merged: Callable[[], list[dict]] | list[dict], + stat_file_path: DPPath | None = None, + ) -> None: + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. + + """ + self.change_out_bias( + merged, + stat_file_path=stat_file_path, + bias_adjust_mode="set-by-statistic", + ) + + def change_out_bias( + self, + sample_merged: Callable[[], list[dict]] | list[dict], + stat_file_path: DPPath | None = None, + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias according to the input data and the pretrained model. + + Parameters + ---------- + sample_merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + stat_file_path : Optional[DPPath] + The path to the stat file. + """ + from deepmd.dpmodel.utils.stat import ( + compute_output_stats, + ) + + if bias_adjust_mode == "change-by-statistic": + delta_bias, out_std = compute_output_stats( + sample_merged, + self.get_ntypes(), + keys=list(self.atomic_output_def().keys()), + stat_file_path=stat_file_path, + model_forward=self._get_forward_wrapper_func(), + rcond=self.rcond, + preset_bias=self.preset_out_bias, + stats_distinguish_types=self.get_compute_stats_distinguish_types(), + intensive=self.get_intensive(), + ) + self._store_out_stat(delta_bias, out_std, add=True) + elif bias_adjust_mode == "set-by-statistic": + bias_out, std_out = compute_output_stats( + sample_merged, + self.get_ntypes(), + keys=list(self.atomic_output_def().keys()), + stat_file_path=stat_file_path, + rcond=self.rcond, + preset_bias=self.preset_out_bias, + stats_distinguish_types=self.get_compute_stats_distinguish_types(), + intensive=self.get_intensive(), + ) + self._store_out_stat(bias_out, std_out) + else: + raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) + + def _store_out_stat( + self, + out_bias: dict[str, np.ndarray], + out_std: dict[str, np.ndarray], + add: bool = False, + ) -> None: + """Store output bias and std into the model.""" + ntypes = self.get_ntypes() + out_bias_data = np.copy(self.out_bias) + out_std_data = np.copy(self.out_std) + for kk in out_bias.keys(): + assert kk in out_std.keys() + idx = self._get_bias_index(kk) + size = self._varsize(self.atomic_output_def()[kk].shape) + if not add: + out_bias_data[idx, :, :size] = out_bias[kk].reshape(ntypes, size) + else: + out_bias_data[idx, :, :size] += out_bias[kk].reshape(ntypes, size) + out_std_data[idx, :, :size] = out_std[kk].reshape(ntypes, size) + self.out_bias = out_bias_data + self.out_std = out_std_data + + def _get_forward_wrapper_func(self) -> Callable[..., dict[str, np.ndarray]]: + """Get a forward wrapper of the atomic model for output bias calculation.""" + import array_api_compat + + from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, + ) + + def model_forward( + coord: np.ndarray, + atype: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + ) -> dict[str, np.ndarray]: + # Get reference array to determine the target array type and device + # Use out_bias as reference since it's always present + ref_array = self.out_bias + xp = array_api_compat.array_namespace(ref_array) + + # Convert numpy inputs to the model's array type with correct device + device = getattr(ref_array, "device", None) + if device is not None: + # For torch tensors + coord = xp.asarray(coord, device=device) + atype = xp.asarray(atype, device=device) + if box is not None: + # Check if box is all zeros before converting + if np.allclose(box, 0.0): + box = None + else: + box = xp.asarray(box, device=device) + if fparam is not None: + fparam = xp.asarray(fparam, device=device) + if aparam is not None: + aparam = xp.asarray(aparam, device=device) + else: + # For numpy arrays + coord = xp.asarray(coord) + atype = xp.asarray(atype) + if box is not None: + if np.allclose(box, 0.0): + box = None + else: + box = xp.asarray(box) + if fparam is not None: + fparam = xp.asarray(fparam) + if aparam is not None: + aparam = xp.asarray(aparam) + + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=self.mixed_types(), + box=box, + ) + atomic_ret = self.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + # Convert outputs back to numpy arrays + return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()} + + return model_forward + def serialize(self) -> dict: return { "type_map": self.type_map, diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index dabbc34e01..cc730ddda6 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -121,7 +121,8 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None: try: # asarray is not within Array API standard, so may fail return np.asarray(x) - except (ValueError, AttributeError, TypeError): + except (ValueError, AttributeError, TypeError, RuntimeError): + # RuntimeError: handles torch tensors with requires_grad=True xp = array_api_compat.array_namespace(x) # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. # Move to CPU device to ensure numpy compatibility diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index 9b0e067972..ad49a7cb8d 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -12,7 +12,7 @@ NoReturn, ) -import numpy as np +import array_api_compat from deepmd.dpmodel.array_api import ( Array, @@ -173,7 +173,18 @@ def extend_descrpt_stat( extend_dstd = des_with_stat["dstd"] else: extend_shape = [len(type_map), *list(des["davg"].shape[1:])] - extend_davg = np.zeros(extend_shape, dtype=des["davg"].dtype) - extend_dstd = np.ones(extend_shape, dtype=des["dstd"].dtype) - des["davg"] = np.concatenate([des["davg"], extend_davg], axis=0) - des["dstd"] = np.concatenate([des["dstd"], extend_dstd], axis=0) + # Use array_api_compat to infer device and dtype from context + xp = array_api_compat.array_namespace(des["davg"]) + extend_davg = xp.zeros( + extend_shape, + dtype=des["davg"].dtype, + device=array_api_compat.device(des["davg"]), + ) + extend_dstd = xp.ones( + extend_shape, + dtype=des["dstd"].dtype, + device=array_api_compat.device(des["dstd"]), + ) + xp = array_api_compat.array_namespace(des["davg"]) + des["davg"] = xp.concat([des["davg"], extend_davg], axis=0) + des["dstd"] = xp.concat([des["dstd"], extend_dstd], axis=0) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index f09ab24dfe..2f9aa69b62 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1049,6 +1049,8 @@ def call( idx_j = xp.reshape(nei_type, (-1,)) # (nf x nl x nnei) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # (ntypes) * ntypes * nt type_embedding_nei = xp.tile( xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 749a5da188..95b66759de 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -369,7 +369,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision)) + result = xp.zeros( + [nf * nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 0a2d46c015..05c9dc77af 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -769,7 +769,9 @@ def call( sw = xp.where( nlist_mask[:, :, None], xp.reshape(sw, (nf * nloc, nnei, 1)), - xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype), + xp.zeros( + (nf * nloc, nnei, 1), dtype=sw.dtype, device=array_api_compat.device(sw) + ), ) # nfnl x nnei x 4 @@ -832,6 +834,8 @@ def call( # (nf x nl x nt_i x nt_j) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # ntypes * (ntypes) * nt type_embedding_i = xp.tile( diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a4089468f3..fabc39ae96 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -261,8 +261,18 @@ def compute_input_stats( fparam_std, ) fparam_inv_std = 1.0 / fparam_std - self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype) - self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.fparam_avg) + self.fparam_avg = xp.asarray( + fparam_avg, + dtype=self.fparam_avg.dtype, + device=array_api_compat.device(self.fparam_avg), + ) + self.fparam_inv_std = xp.asarray( + fparam_inv_std, + dtype=self.fparam_inv_std.dtype, + device=array_api_compat.device(self.fparam_inv_std), + ) # stat aparam if self.numb_aparam > 0: sys_sumv = [] @@ -284,8 +294,18 @@ def compute_input_stats( aparam_std, ) aparam_inv_std = 1.0 / aparam_std - self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype) - self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.aparam_avg) + self.aparam_avg = xp.asarray( + aparam_avg, + dtype=self.aparam_avg.dtype, + device=array_api_compat.device(self.aparam_avg), + ) + self.aparam_inv_std = xp.asarray( + aparam_inv_std, + dtype=self.aparam_inv_std.dtype, + device=array_api_compat.device(self.aparam_inv_std), + ) @abstractmethod def _net_out_dim(self) -> int: @@ -566,7 +586,9 @@ def _call_common( # calculate the prediction if not self.mixed_types: outs = xp.zeros( - [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, net_dim_out], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), ) for type_i in range(self.ntypes): mask = xp.tile( diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 100a0c13b6..4679412d4b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1009,7 +1009,121 @@ def deserialize(cls, data: dict) -> "FittingNet": return FN -FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) +class FittingNet(EmbeddingNet): + """The fitting network. It may be implemented as an embedding + net connected with a linear output layer. + + Parameters + ---------- + in_dim + Input dimension. + out_dim + Output dimension + neuron + The number of neurons in each hidden layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. + bias_out + The last linear layer has bias. + seed : int, optional + Random seed. + trainable : bool or list[bool], optional + Whether the network is trainable. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + bias_out: bool = True, + seed: int | list[int] | None = None, + trainable: bool | list[bool] = True, + ) -> None: + if trainable is None: + trainable = [True] * (len(neuron) + 1) + elif isinstance(trainable, bool): + trainable = [trainable] * (len(neuron) + 1) + else: + pass + super().__init__( + in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + precision=precision, + seed=seed, + trainable=trainable[:-1], + ) + i_in = neuron[-1] if len(neuron) > 0 else in_dim + i_ot = out_dim + self.layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias_out, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + seed=child_seed(seed, len(neuron)), + trainable=trainable[-1], + ) + ) + self.out_dim = out_dim + self.bias_out = bias_out + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "FittingNetwork", + "@version": 1, + "in_dim": self.in_dim, + "out_dim": self.out_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "precision": self.precision, + "bias_out": self.bias_out, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "FittingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Use type(obj.layers[0]) to respect subclass layer types + if obj.layers: + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + else: + obj.layers = type(obj.layers)([]) + return obj class NetworkCollection: diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index 8a45f964f8..bc1146203d 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -100,11 +100,21 @@ def call(self) -> Array: sample_array = self.embedding_net[0]["w"] xp = array_api_compat.array_namespace(sample_array) if not self.use_econf_tebd: - embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype)) + embed = self.embedding_net( + xp.eye( + self.ntypes, + dtype=sample_array.dtype, + device=array_api_compat.device(sample_array), + ) + ) else: embed = self.embedding_net(self.econf_tebd) if self.padding: - embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype) + embed_pad = xp.zeros( + (1, embed.shape[-1]), + dtype=embed.dtype, + device=array_api_compat.device(embed), + ) embed = xp.concat([embed, embed_pad], axis=0) return embed @@ -180,32 +190,51 @@ def change_type_map( "'activation_function' must be 'Linear' when performing type changing on resnet structure!" ) first_layer_matrix = self.embedding_net.layers[0].w - eye_vector = np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision]) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(first_layer_matrix) + eye_vector = xp.eye( + self.ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) # preprocess for resnet connection if self.neuron[0] == self.ntypes: - first_layer_matrix += eye_vector + first_layer_matrix = first_layer_matrix + eye_vector elif self.neuron[0] == self.ntypes * 2: - first_layer_matrix += np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix + xp.concat( + [eye_vector, eye_vector], axis=-1 + ) # randomly initialize params for the unseen types - rng = np.random.default_rng() if has_new_type: - extend_type_params = rng.random( + # Create random params with same dtype and device as first_layer_matrix + extend_type_params = np.random.default_rng().random( [len(type_map), first_layer_matrix.shape[-1]], + dtype=PRECISION_DICT[self.precision], + ) + extend_type_params = xp.asarray( + extend_type_params, dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), ) - first_layer_matrix = np.concatenate( + first_layer_matrix = xp.concat( [first_layer_matrix, extend_type_params], axis=0 ) first_layer_matrix = first_layer_matrix[remap_index] new_ntypes = len(type_map) - eye_vector = np.eye(new_ntypes, dtype=PRECISION_DICT[self.precision]) + eye_vector = xp.eye( + new_ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) if self.neuron[0] == new_ntypes: - first_layer_matrix -= eye_vector + first_layer_matrix = first_layer_matrix - eye_vector elif self.neuron[0] == new_ntypes * 2: - first_layer_matrix -= np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix - xp.concat( + [eye_vector, eye_vector], axis=-1 + ) self.embedding_net.layers[0].num_in = new_ntypes self.embedding_net.layers[0].w = first_layer_matrix diff --git a/deepmd/pt_expt/atomic_model/__init__.py b/deepmd/pt_expt/atomic_model/__init__.py new file mode 100644 index 0000000000..51ee9f4186 --- /dev/null +++ b/deepmd/pt_expt/atomic_model/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_atomic_model import ( + DPAtomicModel, +) +from .energy_atomic_model import ( + DPEnergyAtomicModel, +) + +__all__ = [ + "DPAtomicModel", + "DPEnergyAtomicModel", +] diff --git a/deepmd/pt_expt/atomic_model/dp_atomic_model.py b/deepmd/pt_expt/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..5c00192661 --- /dev/null +++ b/deepmd/pt_expt/atomic_model/dp_atomic_model.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + + +class DPAtomicModel(DPAtomicModelDP, torch.nn.Module): + # Import at class level to set base classes for deserialization + # These will be used by the dpmodel deserialize method to create pt_expt instances + from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, + ) + from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, + ) + + base_descriptor_cls = BaseDescriptor + base_fitting_cls = BaseFitting + + def __init__( + self, descriptor: Any, fitting: Any, *args: Any, **kwargs: Any + ) -> None: + torch.nn.Module.__init__(self) + # Convert descriptor and fitting to pt_expt versions if they are dpmodel instances + # The dpmodel_setattr mechanism will handle this automatically via registry + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + descriptor_pt = try_convert_module(descriptor) + fitting_pt = try_convert_module(fitting) + # If conversion failed (not registered), use original (assume already pt_expt) + if descriptor_pt is None: + descriptor_pt = descriptor + if fitting_pt is None: + fitting_pt = fitting + DPAtomicModelDP.__init__(self, descriptor_pt, fitting_pt, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + DPAtomicModelDP, + lambda v: DPAtomicModel.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/atomic_model/energy_atomic_model.py b/deepmd/pt_expt/atomic_model/energy_atomic_model.py new file mode 100644 index 0000000000..5f34d215cf --- /dev/null +++ b/deepmd/pt_expt/atomic_model/energy_atomic_model.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.energy_atomic_model import ( + DPEnergyAtomicModel as DPEnergyAtomicModelDP, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPEnergyAtomicModel(DPAtomicModel): + """Energy atomic model for pt_expt backend. + + This is a thin wrapper around DPAtomicModel that validates + the fitting is an EnergyFittingNet or InvarFitting. + """ + + pass + + +register_dpmodel_mapping( + DPEnergyAtomicModelDP, + lambda v: DPEnergyAtomicModel.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 4d9469a93a..7feda7d703 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +# Import to register converters +from . import se_t_tebd_block # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) @@ -8,9 +10,17 @@ from .se_r import ( DescrptSeR, ) +from .se_t import ( + DescrptSeT, +) +from .se_t_tebd import ( + DescrptSeTTebd, +) __all__ = [ "BaseDescriptor", "DescrptSeA", "DescrptSeR", + "DescrptSeT", + "DescrptSeTTebd", ] diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 1ccb4d2dda..f8a98abd86 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_a_expt") -@BaseDescriptor.register("se_a_expt") +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 7a406fb499..0484c0dea4 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_r_expt") -@BaseDescriptor.register("se_r_expt") +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py new file mode 100644 index 0000000000..6d732790ca --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") +class DescrptSeT(DescrptSeTDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..f28e1564cc --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3_tebd") +class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTTebdDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd_block.py b/deepmd/pt_expt/descriptor/se_t_tebd_block.py new file mode 100644 index 0000000000..7a0faaf170 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd_block.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + + +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptBlockSeTTebdDP.__init__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + DescrptBlockSeTTebdDP, + lambda v: DescrptBlockSeTTebd.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/__init__.py b/deepmd/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..4a7c8100de --- /dev/null +++ b/deepmd/pt_expt/fitting/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_fitting import ( + BaseFitting, +) +from .ener_fitting import ( + EnergyFittingNet, +) +from .invar_fitting import ( + InvarFitting, +) + +__all__ = [ + "BaseFitting", + "EnergyFittingNet", + "InvarFitting", +] diff --git a/deepmd/pt_expt/fitting/base_fitting.py b/deepmd/pt_expt/fitting/base_fitting.py new file mode 100644 index 0000000000..f42e572578 --- /dev/null +++ b/deepmd/pt_expt/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel.fitting import ( + make_base_fitting, +) + +BaseFitting = make_base_fitting(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py new file mode 100644 index 0000000000..425040ae75 --- /dev/null +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("ener") +class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module): + """Energy fitting net for pt_expt backend. + + This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EnergyFittingNetDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + EnergyFittingNetDP, + lambda v: EnergyFittingNet.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py new file mode 100644 index 0000000000..aa37026284 --- /dev/null +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseFitting.register("invar") +class InvarFitting(InvarFittingDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + InvarFittingDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + InvarFittingDP, + lambda v: InvarFitting.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index bcd3d4450a..f32ec66c54 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -14,6 +14,9 @@ from .network import ( NetworkCollection, ) +from .type_embed import ( + TypeEmbedNet, +) # Register EnvMat with identity converter - it doesn't need wrapping # as it's a stateless utility class @@ -23,4 +26,5 @@ "AtomExcludeMask", "NetworkCollection", "PairExcludeMask", + "TypeEmbedNet", ] diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index b115214056..ee957316c9 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -11,11 +11,11 @@ NativeOP, ) from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP +from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_fitting_network, make_multilayer_network, ) from deepmd.pt_expt.common import ( @@ -114,8 +114,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) -class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): - pass +class FittingNet(FittingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + FittingNetDP.__init__(self, *args, **kwargs) + # Convert dpmodel layers to pt_expt NativeLayer + self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + FittingNetDP, + lambda v: FittingNet.deserialize(v.serialize()), +) class NetworkCollection(NetworkCollectionDP, torch.nn.Module): diff --git a/deepmd/pt_expt/utils/type_embed.py b/deepmd/pt_expt/utils/type_embed.py new file mode 100644 index 0000000000..da4cf09028 --- /dev/null +++ b/deepmd/pt_expt/utils/type_embed.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + +# Import network to ensure EmbeddingNet is registered before TypeEmbedNet is used +from deepmd.pt_expt.utils import network # noqa: F401 + + +class TypeEmbedNet(TypeEmbedNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + TypeEmbedNetDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + # Use common dpmodel_setattr which handles embedding_net conversion via registry + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward(self) -> torch.Tensor: + # Call dpmodel's implementation (now with proper device handling) + return self.call() + + +register_dpmodel_mapping( + TypeEmbedNetDP, + lambda v: TypeEmbedNet.deserialize(v.serialize()), +) diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index a63d4f356a..207355a7f9 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -313,6 +313,104 @@ def test_fitting_net(self) -> None: en1.call(inp) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify FittingNet is a concrete class, not factory-generated.""" + in_dim = 4 + out_dim = 1 + neuron = [8, 16] + net = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + bias_out=True, + ) + # Check it's the actual FittingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "FittingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.out_dim, out_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(net.bias_out, True) + # FittingNet has len(neuron) embedding layers + 1 output layer + self.assertEqual(len(net.layers), len(neuron) + 1) + + def test_forward_pass(self) -> None: + """Test FittingNet forward pass produces correct output shape.""" + in_dim = 4 + out_dim = 3 + neuron = [8, 16, 32] + net = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Single sample + rng = np.random.default_rng() + x = rng.standard_normal(in_dim) + out = net.call(x) + self.assertEqual(out.shape, (out_dim,)) + + # Batch of samples + batch_size = 5 + x_batch = rng.standard_normal((batch_size, in_dim)) + out_batch = net.call(x_batch) + self.assertEqual(out_batch.shape, (batch_size, out_dim)) + + def test_trainable_parameter_variants(self) -> None: + """Test FittingNet with different trainable configurations.""" + in_dim = 4 + out_dim = 2 + neuron = [8, 16] + + # Test 1: All layers trainable (default) + net_all_trainable = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_all_trainable.layers: + self.assertTrue(layer.trainable) + + # Test 2: All layers frozen + net_all_frozen = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_all_frozen.layers: + self.assertFalse(layer.trainable) + + # Test 3: Mixed trainable (embedding layers frozen, output layer trainable) + trainable_list = [False, False, True] # 2 embedding layers + 1 output layer + net_mixed = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=trainable_list, + ) + self.assertFalse(net_mixed.layers[0].trainable) # First embedding layer + self.assertFalse(net_mixed.layers[1].trainable) # Second embedding layer + self.assertTrue(net_mixed.layers[2].trainable) # Output layer + + # Test 4: Serialize/deserialize preserves trainable + serialized = net_mixed.serialize() + net_restored = FittingNet.deserialize(serialized) + for orig_layer, restored_layer in zip( + net_mixed.layers, net_restored.layers, strict=True + ): + self.assertEqual(orig_layer.trainable, restored_layer.trainable) + class TestNetworkCollection(unittest.TestCase): def setUp(self) -> None: diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 49a948c39f..df03f270f5 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_t import DescrptSeT as DescrptSeTPT else: DescrptSeTPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t import DescrptSeT as DescrptSeTPTExpt +else: + DescrptSeTPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF else: @@ -92,6 +97,16 @@ def skip_dp(self) -> bool: ) = self.param return CommonTest.skip_dp + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_tf(self) -> bool: ( @@ -108,6 +123,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTF dp_class = DescrptSeTDP pt_class = DescrptSeTPT + pt_expt_class = DescrptSeTPTExpt jax_class = DescrptSeTJAX array_api_strict_class = DescrptSeTStrict args = descrpt_se_t_args() @@ -184,6 +200,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index e53cd88311..7d33679e69 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) @@ -30,6 +31,12 @@ from deepmd.pt.model.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdPT else: DescrptSeTTebdPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd as DescrptSeTTebdPTExpt, + ) +else: + DescrptSeTTebdPTExpt = None DescrptSeTTebdTF = None if INSTALLED_JAX: from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX @@ -118,6 +125,23 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -159,6 +183,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + pt_expt_class = DescrptSeTTebdPTExpt pd_class = DescrptSeTTebdPD jax_class = DescrptSeTTebdJAX array_api_strict_class = DescrptSeTTebdStrict @@ -241,6 +266,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 74e3b042ab..185a3d5801 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -34,6 +35,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: EnerFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.ener_fitting import ( + EnergyFittingNet as EnerFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + EnerFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.ener import EnerFitting as EnerFittingTF else: @@ -151,9 +159,23 @@ def skip_tf(self) -> bool: ) = self.param return not INSTALLED_TF or default_fparam is not None + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # PyTorch does not support bfloat16 for some operations + return CommonTest.skip_pt_expt or precision == "bfloat16" + tf_class = EnerFittingTF dp_class = EnerFittingDP pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt jax_class = EnerFittingJAX pd_class = EnerFittingPD array_api_strict_class = EnerFittingStrict @@ -237,6 +259,35 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if (numb_fparam and default_fparam is None) # test default_fparam + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if numb_aparam + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, @@ -367,3 +418,377 @@ def atol(self) -> float: return 1e-1 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True,), # resnet_dt + ("float64",), # precision + (True,), # mixed_types + ((3, None),), # (numb_fparam, default_fparam) + ((3, False),), # (numb_aparam, use_aparam_as_mask) + ([],), # atom_ener +) +class TestEnerStat(CommonTest, FittingTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "neuron": [5, 5, 5], + "resnet_dt": resnet_dt, + "precision": precision, + "numb_fparam": numb_fparam, + "numb_aparam": numb_aparam, + "default_fparam": default_fparam, + "seed": 20240217, + "atom_ener": atom_ener, + "use_aparam_as_mask": use_aparam_as_mask, + } + + @property + def skip_pt(self) -> bool: + return CommonTest.skip_pt + + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + + @property + def skip_tf(self) -> bool: + return True + + skip_jax = not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + return not INSTALLED_ARRAY_API_STRICT + + @property + def skip_pd(self) -> bool: + return not INSTALLED_PD + + tf_class = EnerFittingTF + dp_class = EnerFittingDP + pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt + jax_class = EnerFittingJAX + pd_class = EnerFittingPD + array_api_strict_class = EnerFittingStrict + args = fitting_ener() + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + # inconsistent if not sorted + self.atype.sort() + + # Prepare data for compute_input_stats + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + + # Create fparam and aparam with correct dimensions + rng = np.random.default_rng(20240217) + self.fparam = ( + rng.normal(size=(1, numb_fparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_fparam > 0 + else None + ) + self.aparam = ( + rng.normal(size=(1, 6, numb_aparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_aparam > 0 + else None + ) + + self.stat_data = [ + { + "fparam": rng.normal(size=(2, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(2, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + { + "fparam": rng.normal(size=(3, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(3, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + ] + + @property + def additional_data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + } + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.natoms, + self.atype, + self.fparam if numb_fparam else None, + self.aparam if numb_aparam else None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to torch tensors for pt backend + pt_stat_data = [ + { + "fparam": torch.from_numpy(d["fparam"]).to(PT_DEVICE), + "aparam": torch.from_numpy(d["aparam"]).to(PT_DEVICE), + } + for d in self.stat_data + ] + pt_obj.compute_input_stats(pt_stat_data, protection=1e-2) + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # dpmodel's compute_input_stats accepts numpy arrays + pt_expt_obj.compute_input_stats(self.stat_data, protection=1e-2) + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + dp_obj.compute_input_stats(self.stat_data, protection=1e-2) + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + fparam=self.fparam, + aparam=self.aparam, + )["energy"] + + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to jax arrays + jax_stat_data = [ + { + "fparam": jnp.asarray(d["fparam"]), + "aparam": jnp.asarray(d["aparam"]), + } + for d in self.stat_data + ] + jax_obj.compute_input_stats(jax_stat_data, protection=1e-2) + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if self.fparam is not None else None, + aparam=jnp.asarray(self.aparam) if self.aparam is not None else None, + )["energy"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to array_api_strict arrays + strict_stat_data = [ + { + "fparam": array_api_strict.asarray(d["fparam"]), + "aparam": array_api_strict.asarray(d["aparam"]), + } + for d in self.stat_data + ] + array_api_strict_obj.compute_input_stats(strict_stat_data, protection=1e-2) + return to_numpy_array( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) + if self.fparam is not None + else None, + aparam=array_api_strict.asarray(self.aparam) + if self.aparam is not None + else None, + )["energy"] + ) + + def eval_pd(self, pd_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to paddle tensors + pd_stat_data = [ + { + "fparam": paddle.to_tensor(d["fparam"]).to(PD_DEVICE), + "aparam": paddle.to_tensor(d["aparam"]).to(PD_DEVICE), + } + for d in self.stat_data + ] + pd_obj.compute_input_stats(pd_stat_data, protection=1e-2) + return ( + pd_obj( + paddle.to_tensor(self.inputs).to(device=PD_DEVICE), + paddle.to_tensor(self.atype.reshape([1, -1])).to(device=PD_DEVICE), + fparam=( + paddle.to_tensor(self.fparam).to(device=PD_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + paddle.to_tensor(self.aparam).to(device=PD_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt_expt/atomic_model/__init__.py b/source/tests/pt_expt/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py b/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py new file mode 100644 index 0000000000..c393ad4b3b --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py @@ -0,0 +1,471 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + NoReturn, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.path import ( + DPPath, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + + +class FooFitting(BaseFitting, torch.nn.Module): + """Test fitting that returns fixed values for testing bias computation.""" + + def __init__(self): + torch.nn.Module.__init__(self) + BaseFitting.__init__(self) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + return { + "@class": "Fitting", + "type": "foo", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict): + return cls() + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + def get_sel_type(self) -> list[int]: + return [] + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + pass + + def get_type_map(self) -> list[str]: + return [] + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), device=self.device + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5, 6 + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), device=self.device + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), device=self.device + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5, 6 from atomic label. + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), device=self.device + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + """Test output statistics computation for pt_expt atomic model.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: vv.detach().cpu().numpy() for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + expected_std = np.ones( + (2, 2, 2), dtype=np.float64 + ) # 2 keys, 2 atypes, 2 max dims. + expected_std[0, :, :1] = np.array([0.0, 0.816496]).reshape( + 2, 1 + ) # updating std for foo based on [5.0, 5.0, 5.0], [5.0, 6.0, 7.0]] + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.0, 6.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference (matching pt backend test) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_std[0, :, :1] = np.array([1.24722, 0.47140]).reshape( + 2, 1 + ) # updating std for foo based on [4.0, 3.0, 2.0], [1.0, 1.0, 1.0]] + expected_ret3 = {} + # new bias [2.666, 1.333] + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): + """Test merging atomic and global stat when atomic label only covers some types.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5.5, nan (only type 0 atoms) + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5.5, 3 from global label. + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + """Test merging atomic (type 0 only) and global stat for type 1.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: vv.detach().cpu().numpy() for kk, vv in x.items()} + + # 1. test run without bias + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + # foo: type 0 from atomic (mean=5.5), type 1 from global (lstsq=3.0) + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.5, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_ret3 = {} + # new bias [2, -5] + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) diff --git a/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py b/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py new file mode 100644 index 0000000000..e09e7c0c91 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py @@ -0,0 +1,759 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + NoReturn, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.path import ( + DPPath, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class FooFitting(BaseFitting, torch.nn.Module): + """Test fitting with multiple outputs for testing global statistics.""" + + def __init__(self): + torch.nn.Module.__init__(self) + BaseFitting.__init__(self) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "pix", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + return { + "@class": "Fitting", + "type": "foo", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict): + return cls() + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + def get_sel_type(self) -> list[int]: + return [] + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + pass + + def get_type_map(self) -> list[str]: + return [] + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["pix"] = ( + torch.Tensor( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ) + .view([nf, nloc, *self.output_def()["pix"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + return ret + + +def _to_numpy(x): + return x.detach().cpu().numpy() + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 1, 3 + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # no bias of pix + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + } + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims. + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference (matching pt backend test) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2, 3, 6], [5, 8, 9]] given bias [1, 3] + ## foo sumed: [11, 22] compared with [5, 7], fit target is [-6, -15] + ## fit bias is [1, -8] + ## old bias + fit bias [2, -5] + ## new model output is [[3, 4, -2], [6, 0, 1]], which sumed to [5, 7] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + def test_preset_bias(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + preset_out_bias = { + "foo": [None, 2], + "bar": np.array([7.0, 5.0, 13.0, 11.0]).reshape(2, 1, 2), + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # foo sums: [5, 7], + # given bias of type 1 being 2, the bias left for type 0 is [5-2*1, 7-2*2] = [3,3] + # the solution of type 0 is 1.8 + foo_bias = np.array([1.8, preset_out_bias["foo"][1]]).reshape(2, 1) + bar_bias = preset_out_bias["bar"] + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2.8, 3.8, 5], [5.8, 7., 8.]] given bias [1.8, 2] + ## foo sumed: [11.6, 20.8] compared with [5, 7], fit target is [-6.6, -13.8] + ## fit bias is [-7, 2] (2 is assigned. -7 is fit to [-8.6, -17.8]) + ## old bias[1.8,2] + fit bias[-7, 2] = [-5.2, 4] + ## new model output is [[-4.2, -3.2, 7], [-1.2, 9, 10]] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[-4.2, -3.2, 7.0], [-1.2, 9.0, 10.0]]).reshape( + 2, 3, 1 + ) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + + def test_preset_bias_all_none(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + preset_out_bias = { + "foo": [None, None], + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied (all None preset = same as no preset) + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + def test_serialize(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "foo", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["A", "B"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + md1 = DPAtomicModel.deserialize(md0.serialize()) + ret1 = md1.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret1[kk]) + + md2 = DPDPAtomicModel.deserialize(md0.serialize()) + args_np = [self.coord_ext, self.atype_ext, self.nlist] + ret2 = md2.forward_common_atomic(*args_np) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret2[kk]) + + +class TestChangeByStatMixedLabels(unittest.TestCase, TestCaseSingleFrameWithNlist): + """Test change-by-statistic with mixed atomic and global labels.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # foo: atomic label + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # pix: global label + "pix": torch.tensor( + np.array([5.0, 12.0]).reshape(2, 1), device=self.device + ), + # bar: global label + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_atom_foo": np.float32(1.0), + "find_pix": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_change_by_statistic(self) -> None: + """Test change-by-statistic with atomic foo + global pix + global bar.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + + # set initial bias + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + + # change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + # foo: atomic label, bias after set-by-stat: [5, 6] + # model output with bias [5,6], atype [[0,0,1],[0,1,1]]: + # [[6, 7, 9], [9, 11, 12]] + # atom_foo labels: [[5, 5, 5], [5, 6, 7]] + # per-atom delta: [[-1, -2, -4], [-4, -5, -5]] + # delta bias (mean per type): type0=-7/3, type1=-14/3 + # new bias = [5-7/3, 6-14/3] = [8/3, 4/3] + # new output: [[11/3, 14/3, 13/3], [20/3, 19/3, 22/3]] + expected_ret3 = {} + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + # pix: global label, bias after set-by-stat: [-2/3, 19/3] + # model pix with bias, atype [[0,0,1],[0,1,1]]: + # [[7/3, 4/3, 22/3], [16/3, 34/3, 31/3]], sums [11, 27] + # labels [5, 12], delta [-6, -15] + # lstsq: delta bias [1, -8], new bias [1/3, -5/3] + # new output: [[10/3, 7/3, -2/3], [19/3, 10/3, 7/3]] + expected_ret3["pix"] = np.array( + [[3.3333, 2.3333, -0.6667], [6.3333, 3.3333, 2.3333]] + ).reshape(2, 3, 1) + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + # bar is too complicated to be manually computed. + + +class TestEnergyModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + """Test statistics computation with real energy fitting net.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # energy data + "energy": torch.tensor( + np.array([10.0, 20.0]).reshape(2, 1), device=self.device + ), + "find_energy": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_energy_stat(self) -> None: + """Test energy statistics computation with real energy fitting net.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + # test run without bias + ret0 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret0) + + # compute statistics + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret1) + + # Check that bias was computed (out_bias should be non-zero) + self.assertFalse(torch.all(md0.out_bias == 0)) + + # test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + np.testing.assert_allclose( + ret1["energy"].detach().cpu().numpy(), + ret2["energy"].detach().cpu().numpy(), + ) + + # test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + ret3 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret3) diff --git a/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py new file mode 100644 index 0000000000..49e60373d4 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithNlistWithVirtual, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + """Test that pt_expt atomic model serialize/deserialize preserves behavior.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + # Test forward pass + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + def test_dp_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt atomic models.""" + nf, nloc, nnei = self.nlist.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + type_map = ["foo", "bar"] + md0 = DPDPAtomicModel(ds, ft, type_map=type_map) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + # dpmodel uses numpy arrays + args0 = [self.coord_ext, self.atype_ext, self.nlist] + # pt_expt uses torch tensors + args1 = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + np.testing.assert_allclose( + ret0["energy"], + ret1["energy"].detach().cpu().numpy(), + ) + + def test_exportable(self) -> None: + """Test that pt_expt atomic model can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(self.device) + md0 = md0.eval() + + # Prepare inputs for export + coord = torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device) + atype = torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=self.device) + + # Test forward pass + ret0 = md0(coord, atype, nlist) + self.assertIn("energy", ret0) + + # Test torch.export + exported = torch.export.export( + md0, + (coord, atype, nlist), + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret1 = exported.module()(coord, atype, nlist) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_excl_consistency(self) -> None: + """Test that exclusion masks work correctly after serialize/deserialize.""" + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + # hacking! + md1.descriptor.reinit_exclude(pair_excl) + md1.fitting.reinit_exclude(atom_excl) + + # check energy consistency + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + # check output def + out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] + self.assertEqual(out_names, ["energy", "mask"]) + if atom_excl != []: + for ii in md0.atomic_output_def().get_data().values(): + if ii.name == "mask": + self.assertEqual(ii.shape, [1]) + self.assertFalse(ii.reducible) + self.assertFalse(ii.r_differentiable) + self.assertFalse(ii.c_differentiable) + + # check mask + if atom_excl == []: + pass + elif atom_excl == [1]: + self.assertIn("mask", ret0.keys()) + expected = np.array([1, 1, 0], dtype=int) + expected = np.concatenate( + [expected, expected[self.perm[: self.nloc]]] + ).reshape(2, 3) + np.testing.assert_array_equal( + ret0["mask"].detach().cpu().numpy(), expected + ) + else: + raise ValueError(f"not expected atom_excl {atom_excl}") + + +class TestDPAtomicModelVirtualConsistency(unittest.TestCase): + def setUp(self) -> None: + self.case0 = TestCaseSingleFrameWithNlist() + self.case1 = TestCaseSingleFrameWithNlistWithVirtual() + self.case0.setUp() + self.case1.setUp() + self.device = env.DEVICE + + def test_virtual_consistency(self) -> None: + nf, _, _ = self.case0.nlist.shape + ds = DescrptSeA( + self.case0.rcut, + self.case0.rcut_smth, + self.case0.sel, + ) + ft = InvarFitting( + "energy", + self.case0.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + type_map = ["foo", "bar"] + md1 = DPAtomicModel(ds, ft, type_map=type_map).to(self.device) + + args0 = [ + torch.tensor(self.case0.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.case0.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.case0.nlist, dtype=torch.int64, device=self.device), + ] + args1 = [ + torch.tensor(self.case1.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.case1.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.case1.nlist, dtype=torch.int64, device=self.device), + ] + + ret0 = md1.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + + for dd in range(self.case0.nf): + np.testing.assert_allclose( + ret0["energy"][dd].detach().cpu().numpy(), + ret1["energy"][dd, self.case1.get_real_mapping[dd], :] + .detach() + .cpu() + .numpy(), + ) + expected_mask = np.array( + [ + [1, 0, 1, 1], + [1, 1, 0, 1], + ] + ) + np.testing.assert_equal(ret1["mask"].detach().cpu().numpy(), expected_mask) diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py new file mode 100644 index 0000000000..921f10a54a --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeT as DPDescrptSeT +from deepmd.pt_expt.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeT(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeT.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeT.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t returns None for gr/g2/h2, only compare rd and sw + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py new file mode 100644 index 0000000000..e84080882a --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeTTebd as DPDescrptSeTTebd +from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeTTebd(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], # SeTTebd typically uses resnet_dt=True + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeTTebd.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t_tebd should return gr and sw, compare only descriptor and sw for now + # TODO: investigate why gr is None + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if gr1 is not None and gr2 is not None: + np.testing.assert_allclose( + gr1.detach().cpu().numpy(), + gr2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) diff --git a/source/tests/pt_expt/fitting/__init__.py b/source/tests/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/fitting/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py new file mode 100644 index 0000000000..63ae82ab9a --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + efn0 = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + efn1 = EnergyFittingNet.deserialize(efn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = efn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = efn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + """Test that EnergyFittingNet serializes with type='ener' not 'invar'.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + ).to(self.device) + serialized = efn.serialize() + + # Check that the type is 'ener' not 'invar' + self.assertEqual(serialized["type"], "ener") + + # Check that it can be deserialized + efn2 = EnergyFittingNet.deserialize(serialized).to(self.device) + self.assertIsInstance(efn2, EnergyFittingNet) + + def test_torch_export_simple(self) -> None: + """Test that EnergyFittingNet can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = efn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_aparam(self) -> None: + """Test that EnergyFittingNet with aparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=4, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + aparam = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, 4))).to( + self.device + ) + + # Test forward pass works + ret = efn(descriptor, atype, aparam=aparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={"aparam": aparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, aparam=aparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py new file mode 100644 index 0000000000..d682b37145 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + et, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + [[], [0], [1]], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ifn1 = InvarFitting.deserialize(ifn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = ifn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + sel_set = set(ifn0.get_sel_type()) + exclude_set = set(et) + self.assertEqual(sel_set | exclude_set, set(range(self.nt))) + self.assertEqual(sel_set & exclude_set, set()) + + def test_mask(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + od = 2 + mixed_types = True + # exclude type 1 + et = [1] + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + ) + # atom index 2 is of type 1 that is excluded + zero_idx = 2 + np.testing.assert_allclose( + ret0["energy"][0, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][0, zero_idx, :].detach().cpu().numpy()), + ) + zero_idx = 0 + np.testing.assert_allclose( + ret0["energy"][1, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][1, zero_idx, :].detach().cpu().numpy()), + ) + + def test_self_exception( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + ).to(self.device) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + torch.from_numpy(dd[0][:, :, :-2]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input descriptor", str(context.exception)) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( + self.device + ) + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input fparam", str(context.exception)) + + if nap > 0: + iap = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, nap - 1)) + ).to(self.device) + with self.assertRaises(ValueError) as context: + ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input aparam", str(context.exception)) + + def test_get_set(self) -> None: + ifn0 = InvarFitting( + "energy", + self.nt, + 3, + 1, + ).to(self.device) + rng = np.random.default_rng(GLOBAL_SEED) + foo = rng.normal([3, 4]) + for ii in [ + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ]: + ifn0[ii] = torch.from_numpy(foo).to(self.device) + np.testing.assert_allclose( + foo, ifn0[ii].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_torch_export_simple(self) -> None: + """Test that InvarFitting can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=0, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={}, + strict=False, # Use strict=False for now to handle dynamic shapes + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_fparam(self) -> None: + """Test that InvarFitting with fparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=3, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + fparam = torch.from_numpy(rng.normal(size=(self.nf, 3))).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype, fparam=fparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={"fparam": fparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, fparam=fparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py new file mode 100644 index 0000000000..b473c9309c --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + + +def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): + """Make fake data as numpy arrays for dpmodel compute_input_stats.""" + merged_output_stat = [] + nsys = len(sys_natoms) + ndof = len(avgs) + for ii in range(nsys): + sys_dict = {} + tmp_data_f = [] + tmp_data_a = [] + for jj in range(ndof): + rng = np.random.default_rng(2025 * ii + 220 * jj) + tmp_data_f.append( + rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1)) + ) + rng = np.random.default_rng(220 * ii + 1636 * jj) + tmp_data_a.append( + rng.normal( + loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii]) + ) + ) + tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0)) + tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0)) + # dpmodel's compute_input_stats expects numpy arrays + sys_dict["fparam"] = tmp_data_f + sys_dict["aparam"] = tmp_data_a + merged_output_stat.append(sys_dict) + return merged_output_stat + + +def _brute_fparam_pt(data, ndim): + adata = [ii["fparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +def _brute_aparam_pt(data, ndim): + adata = [ii["aparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +class TestEnerFittingStat(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + + def test(self) -> None: + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ).to(self.device) + avgs = [0, 10, 100] + stds = [2, 0.4, 0.00001] + sys_natoms = [10, 100] + sys_nframes = [5, 2] + all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds) + frefa, frefs = _brute_fparam_pt(all_data, len(avgs)) + arefa, arefs = _brute_aparam_pt(all_data, len(avgs)) + fitting.compute_input_stats(all_data, protection=1e-2) + frefs_inv = 1.0 / frefs + arefs_inv = 1.0 / arefs + frefs_inv[frefs_inv > 100] = 100 + arefs_inv[arefs_inv > 100] = 100 + # fparam_avg and fparam_inv_std are torch tensors on device + fparam_avg_np = ( + fitting.fparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_avg) + else fitting.fparam_avg + ) + fparam_inv_std_np = ( + fitting.fparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_inv_std) + else fitting.fparam_inv_std + ) + aparam_avg_np = ( + fitting.aparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_avg) + else fitting.aparam_avg + ) + aparam_inv_std_np = ( + fitting.aparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_inv_std) + else fitting.aparam_inv_std + ) + np.testing.assert_almost_equal(frefa, fparam_avg_np) + np.testing.assert_almost_equal(frefs_inv, fparam_inv_std_np) + np.testing.assert_almost_equal(arefa, aparam_avg_np) + np.testing.assert_almost_equal(arefs_inv, aparam_inv_std_np) diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index 24d61c5fd5..54f12554c1 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -281,3 +281,124 @@ def test_trainable_parameter_handling(self) -> None: for layer in net_frozen.layers: if layer.w is not None: self.assertFalse(layer.w.requires_grad) + + +class TestFittingNetRefactor(unittest.TestCase): + """Tests for the refactored FittingNet pt_expt wrapper.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.out_dim = 1 + self.neuron = [8, 16] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def test_pt_expt_fitting_net_wraps_dpmodel(self) -> None: + """Verify pt_expt FittingNet correctly wraps dpmodel.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_fitting_net_forward(self) -> None: + """Test pt_expt FittingNet forward pass returns torch.Tensor.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.out_dim)) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt FittingNet serialization/deserialization.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = FittingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + out2.detach().cpu().numpy(), + ) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that dpmodel FittingNet can be converted to pt_expt via registry.""" + from deepmd.dpmodel.utils.network import FittingNet as DPFittingNet + from deepmd.pt_expt.common import ( + try_convert_module, + ) + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + # Create dpmodel FittingNet + dp_net = DPFittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt FittingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, FittingNet) + + # Verify layers are pt_expt modules + for layer in converted.layers: + self.assertIsInstance(layer, torch.nn.Module)