diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 173f538a27..469b6c008f 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -27,6 +27,9 @@ NativeLayer, get_activation_fn, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -415,7 +418,7 @@ def call( # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( - xp.linalg.vector_norm(a_diff, axis=-1, keepdims=True) + 1e-6 + safe_for_vector_norm(a_diff, axis=-1, keepdims=True) + 1e-6 ) # nf x nloc x 3 x a_nnei normalized_diff_j = xp.matrix_transpose(normalized_diff_i) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index e15a20926f..aebc43a30f 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -28,6 +28,9 @@ NativeLayer, get_activation_fn, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -414,7 +417,7 @@ def call( if not self.direct_dist: g2, h2 = xp.split(dmatrix, [1], axis=-1) else: - g2, h2 = xp.linalg.vector_norm(diff, axis=-1, keepdims=True), diff + g2, h2 = safe_for_vector_norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nf x nloc x nnei x ng2 diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index bcecf62775..6c17ad7702 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -13,6 +13,9 @@ support_array_api, xp_take_along_axis, ) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) @support_array_api(version="2023.12") @@ -58,8 +61,7 @@ def _make_env_mat( diff = coord_r - coord_l # nf x nloc x nnei # the grad of JAX vector_norm is NaN at x=0 - diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff) - length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True) + length = safe_for_vector_norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype) t0 = 1 / (length + protection)