From 51985af085cb73093315a55f12d90b0122a416bd Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Feb 2026 05:48:07 +0800 Subject: [PATCH 1/4] fix(jax): improve JAX modules' names Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be regonized. I realized it when implementing #5213. --- deepmd/jax/common.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 27c7f8883d..3acc51ce93 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from functools import ( + wraps, +) from typing import ( Any, overload, @@ -42,18 +45,18 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None: def flax_module( - module: NativeOP, -) -> nnx.Module: + module: type[NativeOP], +) -> type[nnx.Module]: """Convert a NativeOP to a Flax module. Parameters ---------- - module : NativeOP + module : type[NativeOP] The NativeOP to convert. Returns ------- - flax.nnx.Module + type[flax.nnx.Module] The Flax module. Examples @@ -72,6 +75,7 @@ class MixedMetaClass(*metas): def __call__(self, *args: Any, **kwargs: Any) -> Any: return type(nnx.Module).__call__(self, *args, **kwargs) + @wraps(module, updated=()) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): def __init_subclass__(cls, **kwargs: Any) -> None: return super().__init_subclass__(**kwargs) From 551721a594cc18de26dc7479320774bd6473ff6c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 11 Feb 2026 00:26:14 +0800 Subject: [PATCH 2/4] address Copilot's suggestion Refactor flax_module to use a generic type for better type safety. Signed-off-by: Jinzhe Zeng --- deepmd/jax/common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 3acc51ce93..073d39c102 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -44,9 +44,12 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None: return jnp.array(array) +T = TypeVar('T') + + def flax_module( - module: type[NativeOP], -) -> type[nnx.Module]: + module: type[T], +) -> type[T]: # runtime: actually returns type[T & nnx.Module] """Convert a NativeOP to a Flax module. Parameters From 3688dd23549fe56a9def016ec5d3f65d42c3aaf1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:28:04 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/jax/common.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 073d39c102..8adcf3e4dc 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -9,9 +9,6 @@ import numpy as np -from deepmd.dpmodel.common import ( - NativeOP, -) from deepmd.jax.env import ( jnp, nnx, @@ -44,12 +41,12 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None: return jnp.array(array) -T = TypeVar('T') +T = TypeVar("T") def flax_module( module: type[T], -) -> type[T]: # runtime: actually returns type[T & nnx.Module] +) -> type[T]: # runtime: actually returns type[T & nnx.Module] """Convert a NativeOP to a Flax module. Parameters From f1b2b3e46349a433e2f2d6fa451140c49bb57369 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 11 Feb 2026 00:35:47 +0800 Subject: [PATCH 4/4] Add TypeVar import to common.py Signed-off-by: Jinzhe Zeng --- deepmd/jax/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 8adcf3e4dc..2bfa6dc5e7 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -4,6 +4,7 @@ ) from typing import ( Any, + TypeVar, overload, )