From bd2b832d810d439b17583efd89317b789ac4d843 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 23 Mar 2025 00:39:35 +0800 Subject: [PATCH 1/2] fix(jax): use `safe_for_vector_norm` for env mat and dpa3 This fixes the NaN when calculating the Hessian in the JAX backend. Reproduce NaN: ```py from deepmd.pt.utils.serialization import serialize_from_file from deepmd.jax.model.ener_model import EnergyModel from deepmd.jax.env import jax, jnp import numpy as np jax.config.update("jax_debug_nans", True) model = serialize_from_file('frozen_model.pth') model = EnergyModel.deserialize(model["model"]) model.enable_hessian() model_call = jax.jit(model.call) # nframes x natoms x 3 coord = np.array([[[0,0,0],[1,1,1]]], dtype=np.float64) # nframes x natoms atype = np.array([[0,1]], dtype=int) print(model_call(jnp.array(coord), jnp.array(atype), None)['energy_derv_r_derv_r']) ``` Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/repflows.py | 5 ++++- deepmd/dpmodel/utils/env_mat.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 173f538a27..469b6c008f 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -27,6 +27,9 @@ NativeLayer, get_activation_fn, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -415,7 +418,7 @@ def call( # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( - xp.linalg.vector_norm(a_diff, axis=-1, keepdims=True) + 1e-6 + safe_for_vector_norm(a_diff, axis=-1, keepdims=True) + 1e-6 ) # nf x nloc x 3 x a_nnei normalized_diff_j = xp.matrix_transpose(normalized_diff_i) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index bcecf62775..6c17ad7702 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -13,6 +13,9 @@ support_array_api, xp_take_along_axis, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) @support_array_api(version="2023.12") @@ -58,8 +61,7 @@ def _make_env_mat( diff = coord_r - coord_l # nf x nloc x nnei # the grad of JAX vector_norm is NaN at x=0 - diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff) - length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True) + length = safe_for_vector_norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype) t0 = 1 / (length + protection) From 62c3f2bb2ccd66f3c53f3e81d59c34edc9afa438 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 25 Mar 2025 20:35:13 +0800 Subject: [PATCH 2/2] update repformers Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/repformers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index e15a20926f..aebc43a30f 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -28,6 +28,9 @@ NativeLayer, get_activation_fn, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -414,7 +417,7 @@ def call( if not self.direct_dist: g2, h2 = xp.split(dmatrix, [1], axis=-1) else: - g2, h2 = xp.linalg.vector_norm(diff, axis=-1, keepdims=True), diff + g2, h2 = safe_for_vector_norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nf x nloc x nnei x ng2