From 95d35c03f0155c9b1fde4f0b4690b5d3aff4a0a6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 28 Feb 2025 22:16:31 +0800 Subject: [PATCH 1/2] fix(array-api): fix compat with Array API 2024.12 Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/dpa1.py | 1 + deepmd/dpmodel/descriptor/repformers.py | 1 + deepmd/dpmodel/descriptor/se_t_tebd.py | 1 + 3 files changed, 3 insertions(+) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 37b677cf23..2ce05bd160 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -899,6 +899,7 @@ def call( exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + exclude_mask = xp.astype(exclude_mask, xp.bool) # nfnl x nnei nlist = xp.reshape(nlist, (nf * nloc, nnei)) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index ae6b5de511..e15a20926f 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -393,6 +393,7 @@ def call( ): xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + exclude_mask = xp.astype(exclude_mask, xp.bool) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 570f9a47e8..1b0f44ec97 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -682,6 +682,7 @@ def call( exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei nlist = xp.reshape(nlist, (nf * nloc, nnei)) + exclude_mask = xp.astype(exclude_mask, xp.bool) nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei nlist_mask = nlist != -1 From 6c25b3981042c47737c6a1f60d4718fb5892daae Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 28 Feb 2025 23:04:26 +0800 Subject: [PATCH 2/2] exclude_mask bool Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 388932f297..7342663141 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -488,6 +488,7 @@ def _call_common( ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) + exclude_mask = xp.astype(exclude_mask, xp.bool) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) return {self.var_name: outs}