From 94835cde0b51731d173f25f9d5fa8a40d19b08b1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 May 2025 01:30:49 +0800 Subject: [PATCH] fix(jax): fix repflows JIT issues Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/repflows.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 926b500645..c8366969d6 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1324,7 +1324,10 @@ def call( ) nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] - n_edge = int(xp.sum(xp.astype(nlist_mask, xp.int32))) + # int cannot jit; do not run it when self.use_dynamic_sel == False + n_edge = ( + int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 + ) node_ebd = node_ebd_ext[:, :nloc, :] assert (nb, nloc) == node_ebd.shape[:2] if not self.use_dynamic_sel: