From c558c1447293ff44a7f17c7f3305760f8b4cae63 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 13 Jun 2026 23:34:12 +0800 Subject: [PATCH] refactor(tests): auto-convert array-api-strict modules --- source/tests/array_api_strict/common.py | 211 +++++++++++++++++- .../array_api_strict/descriptor/__init__.py | 20 ++ .../tests/array_api_strict/descriptor/dpa1.py | 67 ++---- .../tests/array_api_strict/descriptor/dpa2.py | 53 +---- .../tests/array_api_strict/descriptor/dpa3.py | 34 +-- .../array_api_strict/descriptor/hybrid.py | 15 +- .../array_api_strict/descriptor/repflows.py | 55 +---- .../array_api_strict/descriptor/repformers.py | 81 ++----- .../descriptor/se_atten_v2.py | 9 + .../array_api_strict/descriptor/se_e2_a.py | 29 +-- .../array_api_strict/descriptor/se_e2_r.py | 29 +-- .../tests/array_api_strict/descriptor/se_t.py | 29 +-- .../array_api_strict/descriptor/se_t_tebd.py | 41 +--- .../array_api_strict/fitting/__init__.py | 15 ++ .../tests/array_api_strict/fitting/fitting.py | 62 ++--- .../array_api_strict/utils/exclude_mask.py | 18 +- .../tests/array_api_strict/utils/network.py | 65 +++++- .../array_api_strict/utils/type_embed.py | 18 +- 18 files changed, 421 insertions(+), 430 deletions(-) diff --git a/source/tests/array_api_strict/common.py b/source/tests/array_api_strict/common.py index 50109ded86..546e111dd2 100644 --- a/source/tests/array_api_strict/common.py +++ b/source/tests/array_api_strict/common.py @@ -1,11 +1,29 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Callable, +) +from functools import ( + wraps, +) +from importlib import ( + import_module, +) +from typing import ( + Any, + TypeVar, +) + import array_api_strict import numpy as np +from deepmd.dpmodel.common import ( + NativeOP, +) + -def to_array_api_strict_array(array: np.ndarray | None): - """Convert a numpy array to a JAX array. +def to_array_api_strict_array(array: np.ndarray | None) -> Any: + """Convert a numpy array to an array-api-strict array. Parameters ---------- @@ -14,9 +32,194 @@ def to_array_api_strict_array(array: np.ndarray | None): Returns ------- - jnp.ndarray - The JAX tensor. + array_api_strict.Array + The array-api-strict array. """ if array is None: return None return array_api_strict.asarray(array) + + +_PACKAGE_ROOT = __name__.rsplit(".", 1)[0] +_DPMODEL_TO_STRICT: dict[type[Any], Callable[[Any], Any]] = {} +_AUTO_WRAPPED_CLASSES: dict[type[NativeOP], type[Any]] = {} +_REGISTRATIONS_READY = False +_REGISTRATIONS_IN_PROGRESS = False +_REGISTRATION_MODULES = ( + f"{_PACKAGE_ROOT}.utils.network", + f"{_PACKAGE_ROOT}.utils.exclude_mask", + f"{_PACKAGE_ROOT}.utils.type_embed", + f"{_PACKAGE_ROOT}.descriptor.dpa1", + f"{_PACKAGE_ROOT}.descriptor.se_atten_v2", + f"{_PACKAGE_ROOT}.descriptor.se_e2_a", + f"{_PACKAGE_ROOT}.descriptor.se_e2_r", + f"{_PACKAGE_ROOT}.descriptor.se_t", + f"{_PACKAGE_ROOT}.descriptor.se_t_tebd", + f"{_PACKAGE_ROOT}.descriptor.repformers", + f"{_PACKAGE_ROOT}.descriptor.dpa2", + f"{_PACKAGE_ROOT}.descriptor.repflows", + f"{_PACKAGE_ROOT}.descriptor.dpa3", + f"{_PACKAGE_ROOT}.descriptor.hybrid", + f"{_PACKAGE_ROOT}.fitting", +) + + +class ArrayAPIList(list): + def append(self, item: Any) -> None: + return super().append(convert_array_api_strict_value(item)) + + def extend(self, items: list[Any]) -> None: + return super().extend(convert_array_api_strict_value(item) for item in items) + + def insert(self, index: int, item: Any) -> None: + return super().insert(index, convert_array_api_strict_value(item)) + + def __setitem__(self, index: Any, item: Any) -> None: + if isinstance(index, slice): + item = [convert_array_api_strict_value(ii) for ii in item] + else: + item = convert_array_api_strict_value(item) + return super().__setitem__(index, item) + + +def register_dpmodel_mapping( + dpmodel_cls: type[Any], converter: Callable[[Any], Any] +) -> None: + """Register how to convert a dpmodel object to its array-api-strict wrapper.""" + _DPMODEL_TO_STRICT[dpmodel_cls] = converter + + +def _looks_like_dpmodel_class(cls: type[Any]) -> bool: + module = cls.__module__ + return module == "deepmd.dpmodel" or module.startswith("deepmd.dpmodel.") + + +def _looks_like_dpmodel_object(value: Any) -> bool: + return _looks_like_dpmodel_class(type(value)) + + +def _looks_like_strict_object(value: Any) -> bool: + module = type(value).__module__ + return module == _PACKAGE_ROOT or module.startswith(f"{_PACKAGE_ROOT}.") + + +def _ensure_registrations() -> None: + global _REGISTRATIONS_IN_PROGRESS, _REGISTRATIONS_READY + + if _REGISTRATIONS_READY or _REGISTRATIONS_IN_PROGRESS: + return + + _REGISTRATIONS_IN_PROGRESS = True + try: + for module in _REGISTRATION_MODULES: + import_module(module) + _REGISTRATIONS_READY = True + finally: + _REGISTRATIONS_IN_PROGRESS = False + + +def try_convert_module(value: Any) -> Any | None: + """Convert a registered dpmodel object to its array-api-strict wrapper.""" + if _looks_like_strict_object(value): + return None + converter = _DPMODEL_TO_STRICT.get(type(value)) + if converter is not None: + return converter(value) + if _looks_like_dpmodel_object(value): + _ensure_registrations() + converter = _DPMODEL_TO_STRICT.get(type(value)) + if converter is not None: + return converter(value) + if isinstance(value, NativeOP): + return _auto_wrap_native_op(value) + return None + + +def _auto_wrap_native_op(value: NativeOP) -> Any: + cls = type(value) + if cls not in _AUTO_WRAPPED_CLASSES: + wrapped_cls = type( + cls.__name__, + (cls,), + { + "__module__": __name__, + "__qualname__": cls.__qualname__, + }, + ) + _AUTO_WRAPPED_CLASSES[cls] = array_api_strict_module(wrapped_cls) + wrapped_cls = _AUTO_WRAPPED_CLASSES[cls] + if not (hasattr(value, "serialize") and hasattr(wrapped_cls, "deserialize")): + raise TypeError( + f"Cannot auto-wrap {cls.__name__}: " + "it must implement serialize()/deserialize() or be explicitly " + "registered via register_dpmodel_mapping()." + ) + return wrapped_cls.deserialize(value.serialize()) + + +def _try_convert_list(value: list[Any], *, keep_converting: bool = False) -> list[Any]: + converted = ArrayAPIList() if keep_converting else [] + changed = keep_converting + for item in value: + converted_item = convert_array_api_strict_value(item) + converted.append(converted_item) + changed = changed or converted_item is not item + return converted if changed else value + + +def convert_array_api_strict_value(value: Any) -> Any: + if isinstance(value, np.ndarray): + return to_array_api_strict_array(value) + + if isinstance(value, list): + return _try_convert_list(value) + + converted = try_convert_module(value) + if converted is not None: + return converted + + return value + + +def array_api_strict_setattr(obj: Any, name: str, value: Any) -> Any: + if name in getattr(obj, "_array_api_strict_skip_auto_convert_attrs", ()): + return value + + if isinstance(value, list) and name in getattr( + obj, "_array_api_strict_data_list_attrs", () + ): + return _try_convert_list(value, keep_converting=True) + + return convert_array_api_strict_value(value) + + +T = TypeVar("T") + + +def array_api_strict_module(module: type[T]) -> type[T]: + """Add array-api-strict conversion to a dpmodel subclass.""" + original_setattr = module.__setattr__ + + @wraps(original_setattr) + def __setattr__(self: Any, name: str, value: Any) -> None: + value = array_api_strict_setattr(self, name, value) + return original_setattr(self, name, value) + + module.__setattr__ = __setattr__ # type: ignore[method-assign] + + if hasattr(module, "deserialize"): + for base in module.__bases__: + if base in (object, NativeOP): + continue + if ( + _looks_like_dpmodel_class(base) + and hasattr(base, "serialize") + and base not in _DPMODEL_TO_STRICT + ): + + def _converter(v: Any, _cls: type[Any] = module) -> Any: + return _cls.deserialize(v.serialize()) + + _DPMODEL_TO_STRICT[base] = _converter + + return module diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py index bd778e364d..1bbefbea6f 100644 --- a/source/tests/array_api_strict/descriptor/__init__.py +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -2,19 +2,39 @@ from .dpa1 import ( DescrptDPA1, ) +from .dpa2 import ( + DescrptDPA2, +) +from .dpa3 import ( + DescrptDPA3, +) from .hybrid import ( DescrptHybrid, ) +from .se_atten_v2 import ( + DescrptSeAttenV2, +) from .se_e2_a import ( DescrptSeA, ) from .se_e2_r import ( DescrptSeR, ) +from .se_t import ( + DescrptSeT, +) +from .se_t_tebd import ( + DescrptSeTTebd, +) __all__ = [ "DescrptDPA1", + "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", + "DescrptSeAttenV2", "DescrptSeR", + "DescrptSeT", + "DescrptSeTTebd", ] diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py index d14444f269..5514a95228 100644 --- a/source/tests/array_api_strict/descriptor/dpa1.py +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP @@ -14,73 +10,38 @@ ) from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - LayerNorm, - NativeLayer, - NetworkCollection, -) -from ..utils.type_embed import ( - TypeEmbedNet, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +from ..utils import type_embed as _strict_type_embed # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) +@array_api_strict_module class GatedAttentionLayer(GatedAttentionLayerDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"in_proj", "out_proj"}: - value = NativeLayer.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): - def __setattr__(self, name: str, value: Any) -> None: - if name == "attention_layer": - value = GatedAttentionLayer.deserialize(value.serialize()) - elif name == "attn_layer_norm": - value = LayerNorm.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class NeighborGatedAttention(NeighborGatedAttentionDP): - def __setattr__(self, name: str, value: Any) -> None: - if name == "attention_layers": - value = [ - NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value - ] - return super().__setattr__(name, value) + pass +@array_api_strict_module class DescrptBlockSeAtten(DescrptBlockSeAttenDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"embeddings", "embeddings_strip"}: - if value is not None: - value = NetworkCollection.deserialize(value.serialize()) - elif name == "dpa1_attention": - value = NeighborGatedAttention.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass @BaseDescriptor.register("dpa1") @BaseDescriptor.register("se_atten") +@array_api_strict_module class DescrptDPA1(DescrptDPA1DP): - def __setattr__(self, name: str, value: Any) -> None: - if name == "se_atten": - value = DescrptBlockSeAtten.deserialize(value.serialize()) - elif name == "type_embedding": - value = TypeEmbedNet.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/dpa2.py b/source/tests/array_api_strict/descriptor/dpa2.py index a510c6b461..f3a145531d 100644 --- a/source/tests/array_api_strict/descriptor/dpa2.py +++ b/source/tests/array_api_strict/descriptor/dpa2.py @@ -1,57 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP -from deepmd.dpmodel.utils.network import Identity as IdentityDP -from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.network import ( - NativeLayer, -) -from ..utils.type_embed import ( - TypeEmbedNet, + array_api_strict_module, ) +from ..utils import network as _strict_network # noqa: F401 +from ..utils import type_embed as _strict_type_embed # noqa: F401 +from . import dpa1 as _strict_dpa1 # noqa: F401 +from . import repformers as _strict_repformers # noqa: F401 +from . import se_t_tebd as _strict_se_t_tebd # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) -from .dpa1 import ( - DescrptBlockSeAtten, -) -from .repformers import ( - DescrptBlockRepformers, -) -from .se_t_tebd import ( - DescrptBlockSeTTebd, -) @BaseDescriptor.register("dpa2") +@array_api_strict_module class DescrptDPA2(DescrptDPA2DP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"repinit"}: - value = DescrptBlockSeAtten.deserialize(value.serialize()) - elif name in {"repinit_three_body"}: - if value is not None: - value = DescrptBlockSeTTebd.deserialize(value.serialize()) - elif name in {"repformers"}: - value = DescrptBlockRepformers.deserialize(value.serialize()) - elif name in {"type_embedding"}: - value = TypeEmbedNet.deserialize(value.serialize()) - elif name in {"g1_shape_tranform", "tebd_transform"}: - if value is None: - pass - elif isinstance(value, NativeLayerDP): - value = NativeLayer.deserialize(value.serialize()) - elif isinstance(value, IdentityDP): - # IdentityDP doesn't contain any value - it's good to go - pass - else: - raise ValueError(f"Unknown layer type: {type(value)}") - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/dpa3.py b/source/tests/array_api_strict/descriptor/dpa3.py index 0086713e93..a6d1059b2e 100644 --- a/source/tests/array_api_strict/descriptor/dpa3.py +++ b/source/tests/array_api_strict/descriptor/dpa3.py @@ -1,40 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP from ..common import ( - to_array_api_strict_array, -) -from ..utils.network import ( - NativeLayer, -) -from ..utils.type_embed import ( - TypeEmbedNet, + array_api_strict_module, ) +from ..utils import network as _strict_network # noqa: F401 +from ..utils import type_embed as _strict_type_embed # noqa: F401 +from . import repflows as _strict_repflows # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) -from .repflows import ( - DescrptBlockRepflows, -) @BaseDescriptor.register("dpa3") +@array_api_strict_module class DescrptDPA3(DescrptDPA3DP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"repflows"}: - value = DescrptBlockRepflows.deserialize(value.serialize()) - elif name in {"type_embedding", "chg_embedding", "spin_embedding"}: - if value is not None: - value = TypeEmbedNet.deserialize(value.serialize()) - elif name in {"mix_cs_mlp"}: - if value is not None: - value = NativeLayer.deserialize(value.serialize()) - else: - pass - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/hybrid.py b/source/tests/array_api_strict/descriptor/hybrid.py index aaaa24ed6b..2d77e7e815 100644 --- a/source/tests/array_api_strict/descriptor/hybrid.py +++ b/source/tests/array_api_strict/descriptor/hybrid.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP from ..common import ( - to_array_api_strict_array, + array_api_strict_module, ) from .base_descriptor import ( BaseDescriptor, @@ -14,11 +10,6 @@ @BaseDescriptor.register("hybrid") +@array_api_strict_module class DescrptHybrid(DescrptHybridDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"nlist_cut_idx"}: - value = [to_array_api_strict_array(vv) for vv in value] - elif name in {"descrpt_list"}: - value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value] - - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/repflows.py b/source/tests/array_api_strict/descriptor/repflows.py index 35fdd65e4f..55aa419bfd 100644 --- a/source/tests/array_api_strict/descriptor/repflows.py +++ b/source/tests/array_api_strict/descriptor/repflows.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Any, + ClassVar, ) from deepmd.dpmodel.descriptor.repflows import ( @@ -9,52 +9,21 @@ from deepmd.dpmodel.descriptor.repflows import RepFlowLayer as RepFlowLayerDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - NativeLayer, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +@array_api_strict_module class DescrptBlockRepflows(DescrptBlockRepflowsDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"layers"}: - value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value] - elif name in {"edge_embd", "angle_embd"}: - value = NativeLayer.deserialize(value.serialize()) - elif name in {"env_mat_edge", "env_mat_angle"}: - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - else: - pass - - return super().__setattr__(name, value) + pass +@array_api_strict_module class RepFlowLayer(RepFlowLayerDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in { - "node_self_mlp", - "node_sym_linear", - "node_edge_linear", - "edge_self_linear", - "a_compress_n_linear", - "a_compress_e_linear", - "edge_angle_linear1", - "edge_angle_linear2", - "angle_self_linear", - }: - if value is not None: - value = NativeLayer.deserialize(value.serialize()) - elif name in {"n_residual", "e_residual", "a_residual"}: - value = [to_array_api_strict_array(vv) for vv in value] - else: - pass - return super().__setattr__(name, value) + _array_api_strict_data_list_attrs: ClassVar[set[str]] = { + "n_residual", + "e_residual", + "a_residual", + } diff --git a/source/tests/array_api_strict/descriptor/repformers.py b/source/tests/array_api_strict/descriptor/repformers.py index ff65ff849f..a13a9321f5 100644 --- a/source/tests/array_api_strict/descriptor/repformers.py +++ b/source/tests/array_api_strict/descriptor/repformers.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Any, + ClassVar, ) from deepmd.dpmodel.descriptor.repformers import ( @@ -17,82 +17,41 @@ from deepmd.dpmodel.descriptor.repformers import RepformerLayer as RepformerLayerDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - LayerNorm, - NativeLayer, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +@array_api_strict_module class DescrptBlockRepformers(DescrptBlockRepformersDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"layers"}: - value = [RepformerLayer.deserialize(layer.serialize()) for layer in value] - elif name == "g2_embd": - value = NativeLayer.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass +@array_api_strict_module class Atten2Map(Atten2MapDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mapqk"}: - value = NativeLayer.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class Atten2MultiHeadApply(Atten2MultiHeadApplyDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mapv", "head_map"}: - value = NativeLayer.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class Atten2EquiVarApply(Atten2EquiVarApplyDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"head_map"}: - value = NativeLayer.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class LocalAtten(LocalAttenDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mapq", "mapkv", "head_map"}: - value = NativeLayer.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass +@array_api_strict_module class RepformerLayer(RepformerLayerDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}: - if value is not None: - value = NativeLayer.deserialize(value.serialize()) - elif name in {"g1_residual", "g2_residual", "h2_residual"}: - value = [to_array_api_strict_array(vv) for vv in value] - elif name in {"attn2g_map"}: - if value is not None: - value = Atten2Map.deserialize(value.serialize()) - elif name in {"attn2_mh_apply"}: - if value is not None: - value = Atten2MultiHeadApply.deserialize(value.serialize()) - elif name in {"attn2_lm"}: - if value is not None: - value = LayerNorm.deserialize(value.serialize()) - elif name in {"attn2_ev_apply"}: - if value is not None: - value = Atten2EquiVarApply.deserialize(value.serialize()) - elif name in {"loc_attn"}: - if value is not None: - value = LocalAtten.deserialize(value.serialize()) - return super().__setattr__(name, value) + _array_api_strict_data_list_attrs: ClassVar[set[str]] = { + "g1_residual", + "g2_residual", + "h2_residual", + } diff --git a/source/tests/array_api_strict/descriptor/se_atten_v2.py b/source/tests/array_api_strict/descriptor/se_atten_v2.py index a2e06ac0e2..84db19fc59 100644 --- a/source/tests/array_api_strict/descriptor/se_atten_v2.py +++ b/source/tests/array_api_strict/descriptor/se_atten_v2.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP +from ..common import ( + register_dpmodel_mapping, +) from .dpa1 import ( DescrptDPA1, ) @@ -8,3 +11,9 @@ class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP): pass + + +register_dpmodel_mapping( + DescrptSeAttenV2DP, + lambda v: DescrptSeAttenV2.deserialize(v.serialize()), +) diff --git a/source/tests/array_api_strict/descriptor/se_e2_a.py b/source/tests/array_api_strict/descriptor/se_e2_a.py index 17da2aafbf..05cb0fde80 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_a.py +++ b/source/tests/array_api_strict/descriptor/se_e2_a.py @@ -1,19 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - NetworkCollection, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) @@ -21,17 +13,6 @@ @BaseDescriptor.register("se_e2_a") @BaseDescriptor.register("se_a") +@array_api_strict_module class DescrptSeA(DescrptSeADP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"dstd", "davg"}: - value = to_array_api_strict_array(value) - elif name in {"embeddings"}: - if value is not None: - value = NetworkCollection.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/se_e2_r.py b/source/tests/array_api_strict/descriptor/se_e2_r.py index b499f4c4c9..12f89d1c4c 100644 --- a/source/tests/array_api_strict/descriptor/se_e2_r.py +++ b/source/tests/array_api_strict/descriptor/se_e2_r.py @@ -1,19 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - NetworkCollection, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) @@ -21,17 +13,6 @@ @BaseDescriptor.register("se_e2_r") @BaseDescriptor.register("se_r") +@array_api_strict_module class DescrptSeR(DescrptSeRDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"dstd", "davg"}: - value = to_array_api_strict_array(value) - elif name in {"embeddings"}: - if value is not None: - value = NetworkCollection.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/se_t.py b/source/tests/array_api_strict/descriptor/se_t.py index 13e650aa17..b85b9ad1f1 100644 --- a/source/tests/array_api_strict/descriptor/se_t.py +++ b/source/tests/array_api_strict/descriptor/se_t.py @@ -1,32 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - NetworkCollection, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +@array_api_strict_module class DescrptSeT(DescrptSeTDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"dstd", "davg"}: - value = to_array_api_strict_array(value) - elif name in {"embeddings"}: - if value is not None: - value = NetworkCollection.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/descriptor/se_t_tebd.py b/source/tests/array_api_strict/descriptor/se_t_tebd.py index 12fc04e69e..b1bdc6e7ca 100644 --- a/source/tests/array_api_strict/descriptor/se_t_tebd.py +++ b/source/tests/array_api_strict/descriptor/se_t_tebd.py @@ -1,47 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.descriptor.se_t_tebd import ( DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, ) from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - PairExcludeMask, -) -from ..utils.network import ( - NetworkCollection, -) -from ..utils.type_embed import ( - TypeEmbedNet, + array_api_strict_module, ) +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +from ..utils import type_embed as _strict_type_embed # noqa: F401 +@array_api_strict_module class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"mean", "stddev"}: - value = to_array_api_strict_array(value) - elif name in {"embeddings", "embeddings_strip"}: - if value is not None: - value = NetworkCollection.deserialize(value.serialize()) - elif name == "env_mat": - # env_mat doesn't store any value - pass - elif name == "emask": - value = PairExcludeMask(value.ntypes, value.exclude_types) - - return super().__setattr__(name, value) + pass +@array_api_strict_module class DescrptSeTTebd(DescrptSeTTebdDP): - def __setattr__(self, name: str, value: Any) -> None: - if name == "se_ttebd": - value = DescrptBlockSeTTebd.deserialize(value.serialize()) - elif name == "type_embedding": - value = TypeEmbedNet.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/fitting/__init__.py b/source/tests/array_api_strict/fitting/__init__.py index 6ceb116d85..2041f600ea 100644 --- a/source/tests/array_api_strict/fitting/__init__.py +++ b/source/tests/array_api_strict/fitting/__init__.py @@ -1 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .fitting import ( + DipoleFittingNet, + DOSFittingNet, + EnergyFittingNet, + PolarFittingNet, + PropertyFittingNet, +) + +__all__ = [ + "DOSFittingNet", + "DipoleFittingNet", + "EnergyFittingNet", + "PolarFittingNet", + "PropertyFittingNet", +] diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index af0e57375b..61238f324f 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP @@ -14,64 +10,32 @@ ) from ..common import ( - to_array_api_strict_array, -) -from ..utils.exclude_mask import ( - AtomExcludeMask, + array_api_strict_module, ) -from ..utils.network import ( - NetworkCollection, -) - - -def setattr_for_general_fitting(name: str, value: Any) -> Any: - if name in { - "bias_atom_e", - "fparam_avg", - "fparam_inv_std", - "aparam_avg", - "aparam_inv_std", - "case_embd", - "default_fparam_tensor", - }: - value = to_array_api_strict_array(value) - elif name == "emask": - value = AtomExcludeMask(value.ntypes, value.exclude_types) - elif name == "nets": - value = NetworkCollection.deserialize(value.serialize()) - return value +from ..utils import exclude_mask as _strict_exclude_mask # noqa: F401 +from ..utils import network as _strict_network # noqa: F401 +@array_api_strict_module class EnergyFittingNet(EnergyFittingNetDP): - def __setattr__(self, name: str, value: Any) -> None: - value = setattr_for_general_fitting(name, value) - return super().__setattr__(name, value) + pass +@array_api_strict_module class PropertyFittingNet(PropertyFittingNetDP): - def __setattr__(self, name: str, value: Any) -> None: - value = setattr_for_general_fitting(name, value) - return super().__setattr__(name, value) + pass +@array_api_strict_module class DOSFittingNet(DOSFittingNetDP): - def __setattr__(self, name: str, value: Any) -> None: - value = setattr_for_general_fitting(name, value) - return super().__setattr__(name, value) + pass +@array_api_strict_module class DipoleFittingNet(DipoleFittingNetDP): - def __setattr__(self, name: str, value: Any) -> None: - value = setattr_for_general_fitting(name, value) - return super().__setattr__(name, value) + pass +@array_api_strict_module class PolarFittingNet(PolarFittingNetDP): - def __setattr__(self, name: str, value: Any) -> None: - value = setattr_for_general_fitting(name, value) - if name in { - "scale", - "constant_matrix", - }: - value = to_array_api_strict_array(value) - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/utils/exclude_mask.py b/source/tests/array_api_strict/utils/exclude_mask.py index 7f5c29e0a8..f92d672389 100644 --- a/source/tests/array_api_strict/utils/exclude_mask.py +++ b/source/tests/array_api_strict/utils/exclude_mask.py @@ -1,25 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from ..common import ( - to_array_api_strict_array, + array_api_strict_module, ) +@array_api_strict_module class AtomExcludeMask(AtomExcludeMaskDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"type_mask"}: - value = to_array_api_strict_array(value) - return super().__setattr__(name, value) + pass +@array_api_strict_module class PairExcludeMask(PairExcludeMaskDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"type_mask"}: - value = to_array_api_strict_array(value) - return super().__setattr__(name, value) + pass diff --git a/source/tests/array_api_strict/utils/network.py b/source/tests/array_api_strict/utils/network.py index 42b0bb5c61..254b7afc31 100644 --- a/source/tests/array_api_strict/utils/network.py +++ b/source/tests/array_api_strict/utils/network.py @@ -7,8 +7,12 @@ from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP +from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP +from deepmd.dpmodel.utils.network import Identity as IdentityDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NativeNet as NativeNetDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( make_embedding_network, @@ -17,22 +21,36 @@ ) from ..common import ( + array_api_strict_module, + register_dpmodel_mapping, to_array_api_strict_array, ) +@array_api_strict_module class NativeLayer(NativeLayerDP): + _array_api_strict_skip_auto_convert_attrs: ClassVar[set[str]] = {"w", "b", "idt"} + def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"}: value = to_array_api_strict_array(value) return super().__setattr__(name, value) -NativeNet = make_multilayer_network(NativeLayer, NativeOP) -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) -FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) +@array_api_strict_module +class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): + pass + +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass + +@array_api_strict_module class NetworkCollection(NetworkCollectionDP): NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { "network": NativeNet, @@ -43,3 +61,44 @@ class NetworkCollection(NetworkCollectionDP): class LayerNorm(LayerNormDP, NativeLayer): pass + + +@array_api_strict_module +class Identity(IdentityDP): + pass + + +register_dpmodel_mapping( + NativeNetDP, + lambda v: NativeNet.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + EmbeddingNetDP, + lambda v: EmbeddingNet.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + FittingNetDP, + lambda v: FittingNet.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + NativeLayerDP, + lambda v: NativeLayer.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + LayerNormDP, + lambda v: LayerNorm.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + NetworkCollectionDP, + lambda v: NetworkCollection.deserialize(v.serialize()), +) + +register_dpmodel_mapping( + IdentityDP, + lambda v: Identity(), +) diff --git a/source/tests/array_api_strict/utils/type_embed.py b/source/tests/array_api_strict/utils/type_embed.py index 7551279002..5321fd1da4 100644 --- a/source/tests/array_api_strict/utils/type_embed.py +++ b/source/tests/array_api_strict/utils/type_embed.py @@ -1,22 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from ..common import ( - to_array_api_strict_array, -) -from ..utils.network import ( - EmbeddingNet, + array_api_strict_module, ) +from . import network as _strict_network # noqa: F401 +@array_api_strict_module class TypeEmbedNet(TypeEmbedNetDP): - def __setattr__(self, name: str, value: Any) -> None: - if name in {"econf_tebd"}: - value = to_array_api_strict_array(value) - if name in {"embedding_net"}: - value = EmbeddingNet.deserialize(value.serialize()) - return super().__setattr__(name, value) + pass