Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from functools import (
wraps,
)
from typing import (
Any,
TypeVar,
overload,
)

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.jax.env import (
jnp,
nnx,
Expand Down Expand Up @@ -41,19 +42,22 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
return jnp.array(array)


T = TypeVar("T")
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def flax_module(
module: NativeOP,
) -> nnx.Module:
module: type[T],
) -> type[T]: # runtime: actually returns type[T & nnx.Module]
"""Convert a NativeOP to a Flax module.

Parameters
----------
module : NativeOP
module : type[NativeOP]
The NativeOP to convert.

Returns
-------
flax.nnx.Module
type[flax.nnx.Module]
The Flax module.
Comment thread
njzjz marked this conversation as resolved.

Examples
Expand All @@ -72,6 +76,7 @@ class MixedMetaClass(*metas):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return type(nnx.Module).__call__(self, *args, **kwargs)

@wraps(module, updated=())
class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
def __init_subclass__(cls, **kwargs: Any) -> None:
Comment thread
njzjz marked this conversation as resolved.
return super().__init_subclass__(**kwargs)
Expand Down