Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,36 @@
Array = np.ndarray | Any # Any to support JAX, PyTorch, etc. arrays


def xp_asarray_nodetach(
xp: Any,
obj: Any,
*,
dtype: Any = None,
device: Any = None,
) -> Array:
"""``xp.asarray`` that preserves autograd for backend tensors.

``torch.asarray`` detaches its input from the autograd graph, so calling
``xp.asarray`` on a weight attribute that is already a backend tensor
(e.g. a ``torch.nn.Parameter`` registered by the pt_expt backend)
silently breaks gradient flow to that weight. This helper converts
genuine non-backend data (numpy arrays, python scalars/lists) via
``xp.asarray``; backend tensors are returned as-is, with an optional
differentiable dtype cast via ``xp.astype``.

The ``device`` argument only applies to the conversion path: backend
tensors are assumed to already live on the working device (they are
created together with the inputs).
"""
if isinstance(obj, np.ndarray) or not array_api_compat.is_array_api_obj(obj):
if dtype is None:
return xp.asarray(obj, device=device)
return xp.asarray(obj, dtype=dtype, device=device)
if dtype is not None and obj.dtype != dtype:
obj = xp.astype(obj, dtype)
return obj


# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816
# but it hasn't been released yet
# below is a pure Python implementation of take_along_axis
Expand Down
22 changes: 16 additions & 6 deletions deepmd/dpmodel/descriptor/dpa4.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
)
from deepmd.dpmodel.common import (
PRECISION_DICT,
get_xp_precision,
Expand Down Expand Up @@ -762,6 +765,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> tuple[Array, Any, Any, Any, Any]:
"""Compute the DPA4 descriptor.

Expand All @@ -780,6 +784,10 @@ def call(
Frame parameters; not used by DPA4 (interface compatibility).
comm_dict
MPI communication metadata; not used (interface compatibility).
charge_spin
Charge/spin embedding input; must be None since
``add_chg_spin_ebd=True`` is rejected at construction
(interface compatibility with ``DPAtomicModel``).

Returns
-------
Expand Down Expand Up @@ -851,10 +859,10 @@ def call(
shift_hat = self.film_shift_norm(shift_logits)
device = array_api_compat.device(scale_hat)
scale_strength = xp.exp(
xp.asarray(self.film_scale_strength_log, device=device)
xp_asarray_nodetach(xp, self.film_scale_strength_log, device=device)
)
shift_strength = xp.exp(
xp.asarray(self.film_shift_strength_log, device=device)
xp_asarray_nodetach(xp, self.film_shift_strength_log, device=device)
)
scale = 1.0 + scale_strength * xp.tanh(scale_hat)
shift = shift_strength * xp.tanh(shift_hat)
Expand Down Expand Up @@ -929,7 +937,9 @@ def _build_gie_zonal_coupling(self, edge_cache: EdgeCache) -> Any:
mp_cols = self.gie.zonal_m0_col_index_for_row[:mp_row_count]
Dt_full = edge_cache.Dt_full
dim_full = Dt_full.shape[-1]
flat_index = xp.asarray(mp_rows * dim_full + mp_cols, device=device)
flat_index = xp_asarray_nodetach(
xp, mp_rows * dim_full + mp_cols, device=device
)
mp_coupling = xp.take(
xp.reshape(Dt_full, (n_edge, dim_full * dim_full)),
flat_index,
Expand Down Expand Up @@ -1043,8 +1053,8 @@ def _variables(self) -> dict[str, np.ndarray]:
# pt interface-compatibility buffers
"version_tensor": np.asarray(self.version, dtype=np.float64),
"_empty_tensor": np.zeros((0,), dtype=np.float64),
"mean": np.asarray(self.mean, dtype=model_np_prec),
"stddev": np.asarray(self.stddev, dtype=model_np_prec),
"mean": to_numpy_array(self.mean).astype(model_np_prec),
"stddev": to_numpy_array(self.stddev).astype(model_np_prec),
}

def add(prefix: str, sub_vars: dict[str, Any]) -> None:
Expand Down Expand Up @@ -1073,7 +1083,7 @@ def add(prefix: str, sub_vars: dict[str, Any]) -> None:
def wigner_buffers(calc: WignerDCalculator) -> dict[str, np.ndarray]:
return {
"l1_perm": np.asarray([1, 2, 0], dtype=np.int64),
"l1_sign_outer": np.asarray(calc.l1_sign_outer, dtype=np.float64),
"l1_sign_outer": to_numpy_array(calc.l1_sign_outer).astype(np.float64),
}

add("wigner_calc.", wigner_buffers(self.wigner_calc))
Expand Down
11 changes: 7 additions & 4 deletions deepmd/dpmodel/descriptor/dpa4_nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
xp_sigmoid,
)
from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -211,8 +212,8 @@ def call(self, x: Any, gate: Any = None) -> Any:
if self.lmax == 0:
return x0

gate_weight = xp.asarray(
self.gate_linear.weight[...], device=array_api_compat.device(x)
gate_weight = xp_asarray_nodetach(
xp, self.gate_linear.weight[...], device=array_api_compat.device(x)
)
input_dtype = gate_scalar_source.dtype
if input_dtype != gate_weight.dtype:
Expand All @@ -224,7 +225,9 @@ def call(self, x: Any, gate: Any = None) -> Any:
gating_scalars,
(x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels),
)
expand_index = xp.asarray(self.expand_index, device=array_api_compat.device(x))
expand_index = xp_asarray_nodetach(
xp, self.expand_index, device=array_api_compat.device(x)
)
gates = xp.take(gating_scalars, expand_index, axis=2) # (N, F, D-1, C)
if self.layout == "ndfc":
gates = xp.permute_dims(gates, (0, 2, 1, 3)) # (N, D-1, F, C)
Expand Down Expand Up @@ -284,7 +287,7 @@ def deserialize(cls, data: dict[str, Any]) -> GatedActivation:
)
prec = PRECISION_DICT[obj.precision.lower()]
expand_index = np.asarray(variables["expand_index"], dtype=np.int64)
if not np.array_equal(expand_index, obj.expand_index):
if not np.array_equal(expand_index, to_numpy_array(obj.expand_index)):
raise ValueError("expand_index does not match the lmax/mmax tables")
if obj.gate_linear is not None:
weight = np.asarray(variables["gate_linear.weight"], dtype=prec)
Expand Down
12 changes: 10 additions & 2 deletions deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
)

from .utils import (
safe_norm,
)
Expand Down Expand Up @@ -323,7 +327,9 @@ def build_edge_cache(
# === Step 4. Rewrite invalid slots to the safe +z dummy vector ===
# Gradient safety: see the function docstring.
maskf = xp.astype(mask_flat, vec.dtype)[:, None] # (E, 1)
z_unit = xp.asarray(np.array([[0.0, 0.0, 1.0]]), dtype=vec.dtype, device=device)
z_unit = xp_asarray_nodetach(
xp, np.array([[0.0, 0.0, 1.0]]), dtype=vec.dtype, device=device
)
edge_vec = vec * maskf + (1.0 - maskf) * z_unit
edge_len = safe_norm(edge_vec, eps) # (E, 1)

Expand All @@ -336,7 +342,9 @@ def build_edge_cache(
if random_gamma:
if gamma is None:
gamma = np.random.default_rng().uniform(0.0, 2.0 * math.pi, n_edge)
gamma = xp.astype(xp.asarray(gamma, device=device), edge_quat.dtype)
gamma = xp.astype(
xp_asarray_nodetach(xp, gamma, device=device), edge_quat.dtype
)
edge_quat = quaternion_multiply(quaternion_z_rotation(gamma), edge_quat)
D_full, Dt_full = wigner_calc(edge_quat)

Expand Down
22 changes: 17 additions & 5 deletions deepmd/dpmodel/descriptor/dpa4_nn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
Expand Down Expand Up @@ -157,8 +160,8 @@ def call(self, atype: Any) -> Any:
Type embeddings with shape (..., embed_dim).
"""
xp = array_api_compat.array_namespace(atype)
weight = xp.asarray(
self.adam_type_embedding[...], device=array_api_compat.device(atype)
weight = xp_asarray_nodetach(
xp, self.adam_type_embedding[...], device=array_api_compat.device(atype)
)
# pt embedding.py:143 torch.embedding -> flat int64 take + reshape.
index = xp.astype(xp.reshape(atype, (-1,)), xp.int64)
Expand Down Expand Up @@ -297,7 +300,8 @@ def call(
if zonal_coupling is None:
Dt_full = edge_cache.Dt_full # (E, D, D)
dim_full = Dt_full.shape[-1]
flat_index = xp.asarray(
flat_index = xp_asarray_nodetach(
xp,
self.non_scalar_row_index * dim_full + self.zonal_m0_col_index_for_row,
device=device,
)
Expand All @@ -310,7 +314,9 @@ def call(
# === Step 3. Broadcast radial features per row ===
# Each non-scalar packed row reuses the radial feature of its degree l
# (pt embedding.py:245-250, index_select on axis 1).
radial_slot_index = xp.asarray(self.radial_slot_index_for_row, device=device)
radial_slot_index = xp_asarray_nodetach(
xp, self.radial_slot_index_for_row, device=device
)
radial_value_for_row = xp.take(
radial_feat, radial_slot_index, axis=1
) # (E, D-1, C)
Expand Down Expand Up @@ -548,7 +554,13 @@ def __init__(
seed=child_seed(seed, 3),
trainable=self.trainable,
)
self.output_proj.w = np.zeros_like(self.output_proj.w)
# Use an explicit shape/dtype instead of np.zeros_like(self.output_proj.w):
# in pt_expt the attribute is a requires-grad torch Parameter, on which
# numpy __array__ conversion raises.
self.output_proj.w = np.zeros(
(self.embed_dim * self.axis_dim, 2 * self.channels),
dtype=PRECISION_DICT[self.precision.lower()],
)

def call(
self,
Expand Down
9 changes: 5 additions & 4 deletions deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
xp_sigmoid,
)
from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -423,8 +424,8 @@ def _to_grid(self, coeff: Any) -> Any:
# einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul
xp = array_api_compat.array_namespace(coeff)
n_batch, coeff_dim, n_focus, _ = coeff.shape
to_grid_mat = xp.asarray(
self.projector.to_grid_mat[...], device=array_api_compat.device(coeff)
to_grid_mat = xp_asarray_nodetach(
xp, self.projector.to_grid_mat[...], device=array_api_compat.device(coeff)
)
if to_grid_mat.dtype != coeff.dtype:
to_grid_mat = xp.astype(to_grid_mat, coeff.dtype)
Expand All @@ -439,8 +440,8 @@ def _from_grid(self, grid: Any) -> Any:
xp = array_api_compat.array_namespace(grid)
n_batch, n_grid, n_focus, _ = grid.shape
coeff_dim = self.projector.coeff_dim
from_grid_mat = xp.asarray(
self.projector.from_grid_mat[...], device=array_api_compat.device(grid)
from_grid_mat = xp_asarray_nodetach(
xp, self.projector.from_grid_mat[...], device=array_api_compat.device(grid)
)
if from_grid_mat.dtype != grid.dtype:
from_grid_mat = xp.astype(from_grid_mat, grid.dtype)
Expand Down
12 changes: 10 additions & 2 deletions deepmd/dpmodel/descriptor/dpa4_nn/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
)


def get_so3_dim_of_lmax(lmax: int) -> int:
"""
Expand Down Expand Up @@ -169,7 +173,9 @@ def project_D_to_m(

xp = array_api_compat.array_namespace(D_full)
D_block = D_full[:, :ebed_dim_full, :ebed_dim_full]
index = xp.asarray(coeff_index_m, device=array_api_compat.device(D_full))
index = xp_asarray_nodetach(
xp, coeff_index_m, device=array_api_compat.device(D_full)
)
proj = xp.take(D_block, index, axis=1)
if cache is not None:
cache[cache_key] = proj
Expand Down Expand Up @@ -223,7 +229,9 @@ def project_Dt_from_m(

xp = array_api_compat.array_namespace(Dt_full)
Dt_block = Dt_full[:, :ebed_dim_full, :ebed_dim_full]
index = xp.asarray(coeff_index_m, device=array_api_compat.device(Dt_full))
index = xp_asarray_nodetach(
xp, coeff_index_m, device=array_api_compat.device(Dt_full)
)
proj = xp.take(Dt_block, index, axis=2)
if cache is not None:
cache[cache_key] = proj
Expand Down
Loading
Loading