diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 27c7f8883d..2bfa6dc5e7 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from functools import ( + wraps, +) from typing import ( Any, + TypeVar, overload, ) import numpy as np -from deepmd.dpmodel.common import ( - NativeOP, -) from deepmd.jax.env import ( jnp, nnx, @@ -41,19 +42,22 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None: return jnp.array(array) +T = TypeVar("T") + + def flax_module( - module: NativeOP, -) -> nnx.Module: + module: type[T], +) -> type[T]: # runtime: actually returns type[T & nnx.Module] """Convert a NativeOP to a Flax module. Parameters ---------- - module : NativeOP + module : type[NativeOP] The NativeOP to convert. Returns ------- - flax.nnx.Module + type[flax.nnx.Module] The Flax module. Examples @@ -72,6 +76,7 @@ class MixedMetaClass(*metas): def __call__(self, *args: Any, **kwargs: Any) -> Any: return type(nnx.Module).__call__(self, *args, **kwargs) + @wraps(module, updated=()) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): def __init_subclass__(cls, **kwargs: Any) -> None: return super().__init_subclass__(**kwargs)