Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,55 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
return x


def xp_maximum_at(x: Array, indices: Array, values: Array) -> Array:
"""Segment max-assign of values into x at the specified indices.

Element-wise analogue of :func:`xp_add_at` that takes the maximum instead
of the sum: for every ``k`` it assigns ``x[indices[k]] = maximum(
x[indices[k]], values[k])``. Repeated indices reduce to the per-segment
maximum, which is order-independent.

Parameters
----------
x : Array
Destination array indexed along axis 0; typically pre-filled with
``-inf`` so empty segments stay neutral.
indices : Array
Integer destination indices with shape (K,).
values : Array
Source values with shape (K, *x.shape[1:]).

Returns
-------
Array
The updated array (modified in place and returned for NumPy; a new
array for JAX/PyTorch).
"""
xp = array_api_compat.array_namespace(x, indices, values)
if array_api_compat.is_numpy_array(x):
# NumPy: in-place ufunc reduction at the given indices.
xp.maximum.at(x, indices, values)
return x

elif array_api_compat.is_jax_array(x):
# JAX: functional indexed-max update, not in-place.
return x.at[indices].max(values)
elif array_api_compat.is_torch_array(x):
import torch

index = indices.reshape([-1] + [1] * (values.ndim - 1)).expand_as(values)
return torch.scatter_reduce(
x, 0, index, values, reduce="amax", include_self=True
)
else:
# Fallback for array_api_strict: basic indexing only.
n = indices.shape[0]
for i in range(n):
idx = int(indices[i])
x[idx, ...] = xp.maximum(x[idx, ...], values[i, ...])
return x


def xp_sigmoid(x: Array) -> Array:
"""Compute the sigmoid function.

Expand Down
Loading
Loading