diff --git a/deepmd/pt/model/atomic_model/sezm_atomic_model.py b/deepmd/pt/model/atomic_model/sezm_atomic_model.py index d27dad8b38..087d1e63f3 100644 --- a/deepmd/pt/model/atomic_model/sezm_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sezm_atomic_model.py @@ -149,6 +149,22 @@ def get_active_mode(self) -> str: """Return the current SeZM execution mode.""" return str(getattr(self, "_active_mode", "ener")) + def get_compute_stats_distinguish_types(self) -> bool: + """Return whether output statistics are type-resolved.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr( + active_fitting, "get_distinguish_types" + ): + return bool(active_fitting.get_distinguish_types()) + return super().get_compute_stats_distinguish_types() + + def get_intensive(self) -> bool: + """Return whether the active reducible output is intensive.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "get_intensive"): + return bool(active_fitting.get_intensive()) + return super().get_intensive() + def _compute_or_load_dens_force_stat( self, sampled_func: Any, @@ -595,9 +611,16 @@ def apply_out_stat( dict[str, torch.Tensor] Outputs after SeZM output-stat post-processing. """ - if "energy" in ret: - out_bias, _ = self._fetch_out_stat(["energy"]) - ret["energy"] = ret["energy"] + out_bias["energy"][atype] + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for key in self.bias_keys: + if key not in ret: + continue + if key == "energy": + ret[key] = ret[key] + out_bias[key][atype] + elif self.get_compute_stats_distinguish_types(): + ret[key] = ret[key] * out_std[key][atype] + out_bias[key][atype] + else: + ret[key] = ret[key] * out_std[key][0] + out_bias[key][0] return ret def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1a01b05fe9..d72b852cfb 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -31,6 +31,7 @@ ) from deepmd.pt.model.task.sezm_ener import ( SeZMEnergyFittingNet, + _resolve_auto_neuron, ) from deepmd.utils.spin import ( Spin, @@ -75,6 +76,9 @@ from .sezm_model import ( SeZMModel, ) +from .sezm_property_model import ( + SeZMPropertyModel, +) from .sezm_spin_model import ( SeZMSpinModel, ) @@ -341,12 +345,37 @@ def get_sezm_model(model_params: dict) -> BaseModel: descriptor = BaseDescriptor(**model_params["descriptor"]) fitting_net = copy.deepcopy(model_params["fitting_net"]) - fitting_net.pop("type", None) + fitting_net.setdefault("type", "dpa4_ener") fitting_net["ntypes"] = descriptor.get_ntypes() fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) fitting_net["mixed_types"] = descriptor.mixed_types() fitting_net["dim_descrpt"] = descriptor.get_dim_out() - fitting = SeZMEnergyFittingNet(**fitting_net) + if fitting_net["type"] in ("dpa4_ener", "sezm_ener"): + fitting = BaseFitting(**fitting_net) + modelcls = SeZMModel + elif fitting_net["type"] == "property": + if bridging_method != "NONE": + raise ValueError( + "DPA4/SeZM property fitting does not support analytical bridging " + "potentials; set `bridging_method` to `none`." + ) + # Share the SeZM auto-width convention + fitting_net["neuron"] = _resolve_auto_neuron( + fitting_net.get("neuron"), + dim_descrpt=fitting_net["dim_descrpt"], + numb_fparam=fitting_net.get("numb_fparam", 0), + numb_aparam=fitting_net.get("numb_aparam", 0), + dim_case_embd=fitting_net.get("dim_case_embd", 0), + case_film_embd=fitting_net.get("case_film_embd", False), + use_aparam_as_mask=fitting_net.get("use_aparam_as_mask", False), + ) + fitting = BaseFitting(**fitting_net) + modelcls = SeZMPropertyModel + else: + raise ValueError( + "DPA4/SeZM model supports `dpa4_ener`, `sezm_ener`, or `property` " + f"fitting, but got `{fitting_net['type']}`." + ) atom_exclude_types = model_params.get("atom_exclude_types", []) preset_out_bias = model_params.get("preset_out_bias") preset_out_bias = _convert_preset_out_bias_to_array( @@ -356,7 +385,7 @@ def get_sezm_model(model_params: dict) -> BaseModel: use_compile = bool(model_params.get("use_compile", False)) enable_tf32 = bool(model_params.get("enable_tf32", True)) - model = SeZMModel( + model = modelcls( descriptor=descriptor, fitting=fitting, type_map=model_params["type_map"], @@ -416,6 +445,12 @@ def get_sezm_spin_model(model_params: dict) -> BaseModel: descriptor = BaseDescriptor(**model_params["descriptor"]) fitting_net = copy.deepcopy(model_params["fitting_net"]) + fitting_net_type = fitting_net.get("type", "dpa4_ener") + if fitting_net_type not in ("dpa4_ener", "sezm_ener"): + raise ValueError( + "Spin DPA4/SeZM currently supports only `dpa4_ener` or `sezm_ener` " + f"fitting, but got `{fitting_net_type}`." + ) fitting_net.pop("type", None) fitting_net["ntypes"] = descriptor.get_ntypes() fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) @@ -461,7 +496,7 @@ def get_model(model_params: dict) -> Any: return get_standard_model(model_params) elif model_type == "linear_ener": return get_linear_model(model_params) - elif model_type in ("SeZM", "sezm", "dpa4"): + elif model_type in ("SeZM", "sezm", "DPA4", "dpa4"): if "spin" in model_params: return get_sezm_spin_model(model_params) return get_sezm_model(model_params) @@ -480,6 +515,7 @@ def get_model(model_params: dict) -> Any: "LinearEnergyModel", "PolarModel", "SeZMModel", + "SeZMPropertyModel", "SeZMSpinModel", "SpinEnergyModel", "SpinModel", diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index c1bc494eec..2f07f2522c 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -487,6 +487,7 @@ ) from deepmd.pt.model.model.transform_output import ( edge_energy_deriv, + fit_output_to_model_output, ) from deepmd.pt.utils import ( env, @@ -1363,6 +1364,7 @@ def core_compute( extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, embedding_only: bool = False, + conservative: bool = True, ) -> dict[str, torch.Tensor]: """ Compute SeZM lower outputs from the unified edge-vector schema. @@ -1406,6 +1408,11 @@ def core_compute( embedding_only When ``True``, return only the embedding outputs and skip the force/virial autograd entirely. + conservative + Whether to run the conservative energy derivative path. Energy + fitting keeps this enabled. Non-conservative property fitting + disables it, so fitting outputs are reduced by their output + definition without constructing edge-force gradients. Returns ------- @@ -1428,8 +1435,9 @@ def core_compute( # This keeps coordinate gathering and shift application outside the # differentiated region while preserving conservative forces through the # scatter indices below. The embedding path produces no force, so it - # keeps ``edge_vec`` detached and never allocates an autograd leaf. - if not embedding_only: + # keeps ``edge_vec`` detached and never allocates an autograd leaf. The + # same forward-only treatment is used by non-conservative property heads. + if conservative and not embedding_only: edge_vec = edge_vec.detach().requires_grad_(True) # === Step 2. Descriptor forward === @@ -1502,10 +1510,20 @@ def core_compute( ).view(out_shape) fit_ret["mask"] = atom_mask + if not conservative: + return fit_output_to_model_output( + fit_ret, + self.atomic_output_def(), + coord, + create_graph=False, + mask=fit_ret["mask"], + extended_coord_corr=extended_coord_corr, + ) + # === Step 5. Inject analytical pair potential (edge form) === # ZBL is evaluated from ``edge_vec`` (the autograd leaf) so its force # and virial flow through the same edge backward as the learned energy. - if self.inter_potential is not None: + if self.inter_potential is not None and "energy" in fit_ret: fit_ret["energy"] = fit_ret["energy"] + self.inter_potential( edge_vec=edge_vec, edge_index=edge_index, @@ -2091,8 +2109,9 @@ def compute_fn( # type: ignore[misc] traced = rebuild_graph_module(traced) # The conservative Inductor option set that keeps the dynamic edge - # graph lowerable is centralised in ``deepmd.pt.utils.compile_compat``. - compile_options = build_inductor_compile_options() + # graph lowerable is centralised in ``deepmd.pt.utils.compile_compat``; + # subclasses may augment it via ``_inductor_compile_options``. + compile_options = self._inductor_compile_options() # NOTE: Store the compiled callable inside the plain-``dict`` # cache ``compiled_core_compute_cache``. The dict itself was installed @@ -2275,6 +2294,14 @@ def should_use_compile(self) -> bool: return self.use_compile return bool(self._env_use_compile_infer) + def _inductor_compile_options(self) -> dict[str, Any]: + """Return the Inductor lowering options for this model's compiled core. + + Subclasses may override this to augment the shared option set from + :func:`build_inductor_compile_options` with model-specific entries. + """ + return build_inductor_compile_options() + # ========================================================================= # Export Utilities # ========================================================================= diff --git a/deepmd/pt/model/model/sezm_property_model.py b/deepmd/pt/model/model/sezm_property_model.py new file mode 100644 index 0000000000..e1fcb7b793 --- /dev/null +++ b/deepmd/pt/model/model/sezm_property_model.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""SeZM model variant for invariant property prediction.""" + +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .sezm_model import ( + SeZMModel, +) + + +@BaseModel.register("SeZMProperty") +@BaseModel.register("sezm_property") +@BaseModel.register("DPA4Property") +@BaseModel.register("dpa4_property") +class SeZMPropertyModel(SeZMModel): + """SeZM sparse-edge model for invariant property fitting. + + The descriptor path, sparse edge construction, compile cache, and type + handling are inherited from :class:`SeZMModel`. The property variant only + changes the readout contract: fitting outputs are reduced by their + ``OutputVariableDef`` and no conservative force or virial derivative is + constructed. + """ + + model_type = "SeZMProperty" + + def __init__( + self, + *args: Any, + bridging_method: str = "none", + **kwargs: Any, + ) -> None: + if str(bridging_method).upper() != "NONE": + raise ValueError( + "SeZM property fitting does not support analytical bridging " + "potentials; set `bridging_method` to `none`." + ) + super().__init__(*args, bridging_method=bridging_method, **kwargs) + + def _translate_property_output( + self, + model_ret: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Translate lower property keys to public prediction keys.""" + var_name = self.get_var_name() + model_predict = { + f"atom_{var_name}": model_ret[var_name], + var_name: model_ret[f"{var_name}_redu"], + } + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + force_input: torch.Tensor | None = None, + noise_mask: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Run invariant property prediction from coordinates.""" + del do_atomic_virial, force_input, noise_mask + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + return self._translate_property_output(model_ret) + + def forward_lower( + self, + coord: torch.Tensor, + atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extended_atype: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Run lower-interface invariant property prediction.""" + del do_atomic_virial + model_ret = self.forward_common_lower( + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fparam=fparam, + aparam=aparam, + comm_dict=comm_dict, + extended_atype=extended_atype, + charge_spin=charge_spin, + ) + return self._translate_property_output(model_ret) + + def core_compute( + self, + coord: torch.Tensor, + atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + extended_atype: torch.Tensor | None = None, + extended_coord_corr: torch.Tensor | None = None, + embedding_only: bool = False, + ) -> dict[str, torch.Tensor]: + """Compute property outputs through the SeZM forward-only graph.""" + return super().core_compute( + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + comm_dict=comm_dict, + extended_atype=extended_atype, + extended_coord_corr=extended_coord_corr, + embedding_only=embedding_only, + conservative=False, + ) + + def _inductor_compile_options(self) -> dict[str, Any]: + """Augment the shared Inductor options for the property compile path. + + The non-conservative property backward graph triggers a TorchInductor + CPU codegen bug: a scalar ``where``/blendv is emitted as + ``decltype(scalar)::blendv(...)``, which the host C++ compiler rejects. + Forcing scalar CPU codegen (``cpp.simdlen = 0``) selects the path that + never emits the vectorized blendv. ``cpp.*`` options affect only the CPU + backend, so CUDA/Triton lowering -- the actual ``use_compile`` deployment + target -- is unchanged. + """ + options = super()._inductor_compile_options() + options["cpp.simdlen"] = 0 + return options + + def translated_output_def(self) -> dict[str, OutputVariableDef]: + """Return public output definitions for property prediction.""" + out_def_data = self.model_output_def().get_data() + var_name = self.get_var_name() + output_def = { + f"atom_{var_name}": out_def_data[var_name], + var_name: out_def_data[f"{var_name}_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def + + def get_task_dim(self) -> int: + """Return the property output dimension.""" + return int(self.get_fitting_net().dim_out) + + def get_intensive(self) -> bool: + """Return whether the reduced property is intensive.""" + return bool(self.model_output_def()[self.get_var_name()].intensive) + + def get_var_name(self) -> str: + """Return the fitted property name.""" + return str(self.get_fitting_net().var_name) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ba1a2cb347..4695a57efc 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3171,7 +3171,7 @@ def standard_model_args() -> Argument: ) def sezm_model_args() -> Argument: doc_descrpt = "Descriptor configuration for atomic environments. DPA4/SeZM uses the SeZM descriptor." - doc_fitting = "Fitting network configuration. DPA4/SeZM uses the `dpa4_ener` GLU energy fitting." + doc_fitting = "Fitting network configuration. DPA4/SeZM uses the `dpa4_ener` GLU energy fitting by default and also supports invariant `property` fitting." doc_model_branch_alias = ( "List of aliases for this model branch. " "Multiple aliases can be defined, and any alias can reference this branch throughout the model usage. " @@ -3234,7 +3234,7 @@ def sezm_model_args() -> Argument: "are saved with LoRA deltas folded into base weights, producing plain " "DPA4/SeZM checkpoints suitable for deployment." ) - doc_model = "DPA4/SeZM model scaffold with fixed SeZM descriptor and fitting types." + doc_model = "DPA4/SeZM model scaffold with the SeZM descriptor and selectable energy or invariant-property fitting." ca = Argument( "dpa4", @@ -3262,7 +3262,10 @@ def sezm_model_args() -> Argument: [ Variant( "type", - [fitting_args_plugin.get_argument("dpa4_ener")], + [ + fitting_args_plugin.get_argument("dpa4_ener"), + fitting_args_plugin.get_argument("property"), + ], optional=True, default_tag="dpa4_ener", doc="The type of the fitting.", @@ -4710,7 +4713,7 @@ def loss_tensor() -> list[Argument]: def loss_variant_type_args() -> Variant: - doc_loss = "The type of the loss. When the fitting type is `ener`, the loss type should be set to `ener`, `dens` (Only DPA4/SeZM supported), or left unset. When the fitting type is `dipole` or `polar`, the loss type should be set to `tensor`." + doc_loss = "The type of the loss. When the fitting type is `ener`, the loss type should be set to `ener`, `dens` (Only DPA4/SeZM supported), or left unset. When the fitting type is `property`, the loss type should be set to `property`. When the fitting type is `dipole` or `polar`, the loss type should be set to `tensor`." return Variant( "type", @@ -4722,7 +4725,7 @@ def loss_variant_type_args() -> Variant: def loss_args() -> list[Argument]: - doc_loss = "The definition of loss function. The loss type should be set to `tensor`, `ener`, `dens` or left unset." + doc_loss = "The definition of loss function. The loss type should be set to `tensor`, `property`, `ener`, `dens` or left unset." ca = Argument( "loss", dict, [], [loss_variant_type_args()], optional=True, doc=doc_loss ) diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 2901b2ac2d..325f1074b1 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -9,10 +9,10 @@ Zone-bridging Model) architecture: an SO(3)-equivariant message-passing model for conservative interatomic potentials. The aliases `DPA4`, `SeZM`, and `sezm` all select the same implementation. -`model.type: "dpa4"` is a convenience scaffold that fixes the SeZM descriptor -and the `dpa4_ener` energy fitting network, so `descriptor.type` and -`fitting_net.type` may be omitted. A new input then needs only the model type, -`type_map`, and a few descriptor options. +`model.type: "dpa4"` is a convenience scaffold that selects the SeZM descriptor +and defaults to the `dpa4_ener` energy fitting network, so `descriptor.type` and +`fitting_net.type` may be omitted for energy training. A new energy input then +needs only the model type, `type_map`, and a few descriptor options. Reference: [DPA4 paper](https://arxiv.org/abs/2606.02419). @@ -128,6 +128,35 @@ predicts energies and forces are obtained by autograd: See [training energy models](train-energy.md) for the general workflow. +### Property training + +DPA4/SeZM can also train invariant structure properties by selecting the +standard property fitting network. Set `fitting_net.type` to `property`, provide +the property name used by the data file, and use the property loss: + +```json +{ + "model": { + "type": "dpa4", + "fitting_net": { + "type": "property", + "property_name": "band_prop", + "task_dim": 3, + "intensive": true + } + }, + "loss": { + "type": "property" + } +} +``` + +The property label follows the usual DeePMD property-data convention: each +system provides `band_prop.npy` when `property_name` is `band_prop`. Property +fitting is not a water task, so the complete DPA4/SeZM property input lives with +the property dataset: see `examples/property/train/input_dpa4.json`, which trains +on the QM9 property subset in `examples/property/data`. + ### Direct-force denoising (`dens`, experimental) DPA4/SeZM has an experimental direct-force denoising head: diff --git a/examples/property/train/README.md b/examples/property/train/README.md index 6e9345395c..44646594c8 100644 --- a/examples/property/train/README.md +++ b/examples/property/train/README.md @@ -1,4 +1,10 @@ -Some explanations of the parameters in `input.json`: +This directory contains property-fitting training inputs on the QM9 subset in +`../data`: + +- `input_torch.json`: DPA1 descriptor. +- `input_dpa4.json`: DPA4/SeZM descriptor. + +Some explanations of the shared property parameters: 1. `fitting_net/property_name` is the name of the property to be predicted. It should be consistent with the property name in the dataset. In each system, code will read `set.*/{property_name}.npy` file as prediction label if you use NumPy format data. 1. `fitting_net/task_dim` is the dimension of model output. It should be consistent with the property dimension in the dataset, which means if the shape of data stored in `set.*/{property_name}.npy` is `batch size * 3`, `fitting_net/task_dim` should be set to 3. diff --git a/examples/property/train/input_dpa4.json b/examples/property/train/input_dpa4.json new file mode 100644 index 0000000000..6995db8e06 --- /dev/null +++ b/examples/property/train/input_dpa4.json @@ -0,0 +1,126 @@ +{ + "_comment": "DPA4/SeZM invariant-property training example on the QM9 property subset in ../data (band_prop holds HOMO/LUMO/band gap, task_dim=3, intensive).", + "model": { + "type": "DPA4", + "type_map": [ + "H", + "C", + "N", + "O" + ], + "descriptor": { + "rcut": 6.0, + "channels": 32, + "n_radial": 16, + "use_env_seed": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 2, + "so2_layers": 3, + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 2, + "focus_dim": 0, + "n_atten_head": 1, + "ffn_neurons": 0, + "ffn_so3_grid": true, + "so3_readout": "mlp", + "grid_mlp": [ + false, + false, + false + ], + "grid_branch": [ + 1, + 1, + 1 + ], + "ffn_blocks": 2, + "sandwich_norm": [ + false, + true, + true, + false + ], + "message_node_so3": true, + "use_amp": true, + "precision": "float32", + "seed": 42 + }, + "fitting_net": { + "type": "property", + "property_name": "band_prop", + "task_dim": 3, + "intensive": true, + "neuron": [ + 0 + ], + "precision": "float32", + "seed": 42 + }, + "use_compile": false, + "enable_tf32": true + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "property", + "metric": [ + "mae" + ], + "loss_func": "smooth_mae", + "beta": 1.0 + }, + "optimizer": { + "type": "HybridMuon", + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1" + ], + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "../data/data_2" + ], + "batch_size": 1, + "numb_batch": 1 + }, + "numb_steps": 2000000, + "gradient_max_norm": 5.0, + "save_freq": 2000, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 1000, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 1, + "seed": 42 + }, + "validating": { + "compiled_infer": false, + "tf32_infer": false + } +} diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 87b0a1cc31..6ec058602b 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -19,6 +19,7 @@ from deepmd.pt.loss import ( DeNSLoss, EnergyStdLoss, + PropertyLoss, ) from deepmd.pt.model.descriptor.sezm_nn import ( GatedActivation, @@ -38,6 +39,9 @@ InterPotential, SeZMModel, ) +from deepmd.pt.model.model.sezm_property_model import ( + SeZMPropertyModel, +) from deepmd.pt.train.training import ( prepare_model_for_loss, ) @@ -936,6 +940,168 @@ def test_multitask_compile_matches_eager(self) -> None: self._assert_multitask_compile_matches_eager(case_film_embd=False) +class TestSeZMModelProperty(unittest.TestCase): + """Test DPA4/SeZM invariant property fitting.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + + @staticmethod + def _randomize_params(model: torch.nn.Module, seed: int = 1234) -> None: + """Fill all trainable tensors with deterministic small values.""" + torch.manual_seed(seed) + with torch.no_grad(): + for param in model.parameters(): + param.copy_(torch.randn_like(param) * 0.1) + + def _build_model_params(self, *, use_compile: bool, intensive: bool) -> dict: + return { + "type": "SeZM", + "type_map": ["A", "B"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "type": "property", + "property_name": "foo", + "task_dim": 3, + "intensive": intensive, + "neuron": [8], + "activation_function": "tanh", + "resnet_dt": True, + "precision": "float32", + "seed": 7, + }, + "use_compile": use_compile, + } + + def _make_tiny_frame( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + generator = torch.Generator(device=self.device).manual_seed(2025) + box = 5.0 * torch.eye( + 3, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=self.device + ) + coord = ( + torch.rand( + [1, 5, 3], + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=self.device, + generator=generator, + ) + @ box + ) + atype = torch.tensor([[0, 0, 1, 1, 0]], dtype=torch.long, device=self.device) + return coord, atype, box.reshape(1, 9) + + def test_forward_shapes_and_reduction(self) -> None: + """Property outputs should use public property keys and reductions.""" + for intensive in (False, True): + model = get_sezm_model( + self._build_model_params(use_compile=False, intensive=intensive) + ).to(self.device) + self.assertIsInstance(model, SeZMPropertyModel) + self.assertEqual(model.get_var_name(), "foo") + self.assertEqual(model.get_task_dim(), 3) + self.assertEqual(model.get_intensive(), intensive) + + coord, atype, box = self._make_tiny_frame() + ret = model(coord, atype, box=box) + self.assertEqual(ret["atom_foo"].shape, (1, 5, 3)) + self.assertEqual(ret["foo"].shape, (1, 3)) + self.assertEqual(ret["mask"].shape, (1, 5)) + self.assertNotIn("force", ret) + self.assertNotIn("virial", ret) + if intensive: + expected = ret["atom_foo"].mean(dim=1) + else: + expected = ret["atom_foo"].sum(dim=1) + torch.testing.assert_close(ret["foo"], expected) + + def test_property_loss_and_serialization(self) -> None: + """PropertyLoss metadata and model serialization should round-trip.""" + from deepmd.pt.model.model.model import ( + BaseModel, + ) + + model = get_sezm_model( + self._build_model_params(use_compile=False, intensive=True) + ).to(self.device) + loss = PropertyLoss( + task_dim=model.get_task_dim(), + var_name=model.get_var_name(), + intensive=model.get_intensive(), + ) + self.assertEqual(loss.var_name, "foo") + + data = model.serialize() + self.assertEqual(data["type"], "SeZMProperty") + model2 = BaseModel.deserialize(data).to(self.device) + self.assertIsInstance(model2, SeZMPropertyModel) + + coord, atype, box = self._make_tiny_frame() + ret = model2(coord, atype, box=box) + self.assertEqual(ret["foo"].shape, (1, 3)) + + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) + def test_compile_matches_eager_and_backpropagates(self) -> None: + """Compiled property forward should match eager and keep gradients.""" + eager = get_sezm_model( + self._build_model_params(use_compile=False, intensive=False) + ).to(self.device) + compiled = get_sezm_model( + self._build_model_params(use_compile=True, intensive=False) + ).to(self.device) + self._randomize_params(eager) + compiled.load_state_dict(eager.state_dict()) + eager.train() + compiled.train() + + coord, atype, box = self._make_tiny_frame() + ret_eager = eager(coord, atype, box=box) + ret_compiled = compiled(coord, atype, box=box) + _assert_close_with_strict_warning( + ret_compiled["foo"], + ret_eager["foo"], + atol=1.0e-5, + rtol=1.0e-5, + msg="compiled property mismatch", + ) + self.assertIn((True, False), compiled.compiled_core_compute_cache) + + loss = ret_compiled["foo"].sum() + loss.backward() + grad_found = any( + param.grad is not None and torch.count_nonzero(param.grad).item() > 0 + for param in compiled.parameters() + ) + self.assertTrue(grad_found) + + class TestInterPotential(unittest.TestCase): """Test InterPotential ZBL analytical pair potential."""