Skip to content

feat(dpmodel): graph-native se_atten attention (NeighborGraph PR-D)#5715

Open
wanghan-iapcm wants to merge 7 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-graph-attn-prD
Open

feat(dpmodel): graph-native se_atten attention (NeighborGraph PR-D)#5715
wanghan-iapcm wants to merge 7 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-graph-attn-prD

Conversation

@wanghan-iapcm

@wanghan-iapcm wanghan-iapcm commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

Implements NeighborGraph PR-D: the graph path now supports attn_layer > 0 for dpa1/se_atten, removing the attn_layer=0-only restriction shipped in #5583.

What

  • Segment toolkit: segment_max + numerically-stable, mask-aware segment_softmax (deepmd/dpmodel/utils/neighbor_graph/segment.py), built on the existing xp_maximum_at.
  • center_edge_pairs (neighbor_graph/pairs.py): pairs of edges sharing a center — the edge-pair axis shared with the upcoming angle machinery (PR-E). Segment-based enumeration (a global (E,E) boolean is deliberately avoided: O(N²·nnei²) memory). Two forms: compact eager (dynamic P, carry-all graphs) and shape-static (P = n_center·nnei², pure arange/reshape arithmetic, no nonzero) for the center-major static layout — this keeps the traced/compiled/export path traceable.
  • DescrptBlockSeAtten._graph_attention: op-for-op ragged mirror of GatedAttentionLayer/NeighborGatedAttention — per-center q@kᵀ becomes per-pair q_m·k_n, softmax over keys becomes segment_softmax grouped by the query edge; head_dim QKV slicing, q/k/v normalize, temperature/scaling, smooth shift trick, post-softmax sw and dotr weighting, residual + LayerNorm per layer.
  • edge_env_mat(return_sw=True) exposes the per-edge switch (zeroed on padding) for the smooth branch.
  • uses_graph_lower widened: attention configs (concat tebd, no exclude_types) are now graph-eligible — pt_expt eager/compiled/exported paths route them through the graph lower by default.

Numerical semantics (reviewed decision)

  • Shape-static adapter path (the dense call adapter, from_dense_quartet(compact=False) + static_nnei): bit-exact vs the dense body, rtol 1e-12, full flag matrix (attn_layer 1/2 × dotr × smooth × normalize × temperature, binding AND non-binding sel).
  • Carry-all graphs: exact for non-smooth attention. For smooth_type_embedding=True, the dense branch keeps sel-padding slots in the attention softmax denominator (weight exp(-attnw_shift)), which makes the dense output depend on sel itself (measured up to ~1e-4 with an identical physical neighbor set). The carry-all form drops those phantom terms by design — the sel-independent math. Pinned by a clean-divergence test; route-equivalence fixtures pin smooth_type_embedding=False.
  • se_atten_v2 (tebd_input_mode="strip") remains graph-ineligible (strip mode is a later PR) — pinned by test.

Testing

  • 38 new dpmodel tests (segment toolkit, pairs incl. random-vs-oracle + static-vs-compact equality, attention parity matrix, binding-sel divergence sanity).
  • pt_expt: test_make_fx_graph_attn (graph forward + autograd at attn_layer=2 traces under make_fx, both smooth branches — required since compiled training uses the graph lower); model-level graph-vs-legacy force/virial/atom-virial parity parametrized over attn_layer {0,2}.
  • Local CPU: common/dpmodel 583, consistent dpa1+se_atten_v2 209, pt_expt descriptor/model/utils 701 (2 failures: dpa4 export inductor error pre-existing on upstream/master, and a route-parity fixture fixed in-branch).
  • GPU-validated (Tesla T4, cuda:0): dpmodel suites 38, pt_expt graph-lower/make_fx/consistency 44 (CUDA 1e-10), route-parity 6, attention AOTI export pipeline + dpa1 cross-backend consistency 105 — all passed.

Known limitations

  • Strip-mode (se_atten_v2) attention stays on the dense path.
  • Carry-all smooth attention diverges from dense by design (see above); old behavior reachable via neighbor_graph_method="legacy" / explicit World-1 builders.
  • num_heads == 1 assumed (dpa1 never exposes num_heads); fail-fast otherwise.
  • Compact center_edge_pairs is eager-only (nonzero); traced paths use the shape-static form.
  • 3-body angles (PR-E), jax graph force (PR-F), dpa2/3 MP (PR-G) unchanged.

Summary by CodeRabbit

  • New Features

    • Graph-based attention support has been expanded, improving compatibility when exporting or tracing models.
    • Neighbor-graph handling now supports more flexible pair enumeration and stable segment-based softmax/max operations.
  • Bug Fixes

    • Improved consistency between graph and dense execution paths for attention-enabled models.
    • Better handling of padded or empty neighbor segments to avoid NaNs and preserve valid outputs.
  • Tests

    • Added broader coverage for graph attention, neighbor-graph utilities, and tracing parity across NumPy and Torch.

Han Wang added 6 commits July 3, 2026 00:10
…ftmax

Built on the existing xp_maximum_at (no new array_api helper needed).
Part of NeighborGraph PR-D (graph-native attention).
Segment-based (global (E,E) boolean deliberately avoided): compact eager
form for carry-all graphs + shape-static nonzero-free form for the
center-major static layout (jit/export/make_fx traceable).
Part of NeighborGraph PR-D; PR-E angles reuse (unordered, no-self).
…r > 0)

DescrptBlockSeAtten.call_graph grows _graph_attention: the dense per-center
(nnei, nnei) attention square becomes the edge-pair axis (center_edge_pairs,
ordered + self-included), softmax over keys becomes segment_softmax grouped
by the query edge. Op-for-op mirror of GatedAttentionLayer.call (head_dim
QKV slicing, normalize q/k/v, temperature/scaling, smooth shift trick,
post-softmax sw and dotr weighting, residual + LayerNorm per layer).

- shape-static adapter path (static_nnei threaded from the dense call
  adapter): bit-exact vs the dense body, rtol 1e-12, full flag matrix
  (attn_layer 1/2 x dotr x smooth x normalize x temperature, binding and
  non-binding sel).
- carry-all (compact) graphs: exact for non-smooth; for smooth the dense
  branch keeps sel-padding slots in the softmax denominator (dense output is
  sel-DEPENDENT, up to ~1e-4) — the carry-all form drops those phantom terms
  by design (user decision 2026-07-03), pinned by a clean-divergence test.
- edge_env_mat(return_sw=True) exposes the per-edge switch (zeroed on
  padding) for the smooth branch.
- uses_graph_lower: attention configs are now graph-eligible (concat tebd,
  no exclude_types still required).
…ial parity

- test_make_fx_graph_attn: graph forward + autograd.grad at attn_layer=2
  traces under make_fx for BOTH smooth branches (the shape-static
  center_edge_pairs form is nonzero-free) — required since pt_expt compiled
  training routes eligible models through the graph lower.
- model-level graph-vs-legacy lower parity now parametrized over
  attn_layer {0, 2} (energy/force/virial/atom_virial, 1e-12 CPU).
- eligibility pins: attention+concat is graph-eligible; se_atten_v2
  (tebd_input_mode='strip') correctly stays dense (strip = later PR;
  the plan's 'se_atten_v2 inherits for free' did not hold).
- linear-model weight tests: pin smooth_type_embedding=False — the standard
  (graph-routed, carry-all) and linear (graph-ineligible, dense) submodels
  otherwise differ by the accepted smooth-attention denominator divergence
  (~1e-6), which is a route artifact, not a weight-combination bug.
- new binding-sel sanity: carry-all graph attention diverges from the
  sel-truncated dense path when sel binds (spec decision deepmodeling#17).
…rity)

neighbor_list=None now takes the carry-all graph default for eligible
attention models; explicit World-1 builders take the legacy dense route.
With smooth attention the two routes differ by design (PR-D), so the
route-equivalence tests pin smooth_type_embedding=False.
@dosubot dosubot Bot added the new feature label Jul 2, 2026
@github-actions github-actions Bot added the Python label Jul 2, 2026
@coderabbitai

coderabbitai Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

DPA1's graph-native call_graph path is extended to support attn_layer > 0 by adding shape-static and compact center-edge-pair enumeration, and mask-aware segment max/softmax reductions. _graph_attention/_graph_attention_one_layer implement transformer attention over graph edges. Test suites are added/updated for parity and smoothing consistency.

Changes

Graph-native attention support

Layer / File(s) Summary
Segment reduction primitives
deepmd/dpmodel/utils/neighbor_graph/segment.py, deepmd/dpmodel/utils/neighbor_graph/__init__.py, source/tests/common/dpmodel/test_segment_softmax.py
Adds segment_max and mask-aware segment_softmax using xp_maximum_at, exported via package __init__, with new unit tests for stability, masking, and torch/numpy parity.
Center edge-pair enumeration
deepmd/dpmodel/utils/neighbor_graph/pairs.py, deepmd/dpmodel/utils/neighbor_graph/env.py, source/tests/common/dpmodel/test_center_edge_pairs.py
Adds center_edge_pairs with compact eager and shape-static (static_nnei) implementations; edge_env_mat gains return_sw to expose the smooth switch; tested for correctness, empty inputs, and torch/numpy parity.
DPA1 graph-native attention forward
deepmd/dpmodel/descriptor/dpa1.py
uses_graph_lower no longer restricts to attn_layer == 0; call_graph/_call_graph_adapter thread static_nnei; _graph_attention/_graph_attention_one_layer implement graph-native transformer attention using center_edge_pairs and segment_softmax, including smooth/non-smooth and dotr branches.
Parity tests and smoothing pinning
source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py, source/tests/common/dpmodel/test_dpa1_call_graph_block.py, source/tests/pt_expt/descriptor/test_dpa1.py, source/tests/pt_expt/model/test_dpa1_graph_lower.py, source/tests/pt_expt/model/test_linear_model.py, source/tests/pt_expt/utils/test_neighbor_list.py
Adds dense-vs-graph attention parity tests, removes obsolete fail-fast test, adds FX-traceability test for graph attention, extends graph-lower parity across attn_layer, and pins smooth_type_embedding=False in unrelated model/neighbor-list tests to avoid divergence.

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
    participant DescrptDPA1
    participant DescrptBlockSeAtten
    participant CenterEdgePairs
    participant GraphAttention

    DescrptDPA1->>DescrptBlockSeAtten: call_graph(graph, atype, static_nnei)
    DescrptBlockSeAtten->>DescrptBlockSeAtten: edge_env_mat(return_sw=True)
    DescrptBlockSeAtten->>CenterEdgePairs: center_edge_pairs(dst, edge_mask, static_nnei)
    CenterEdgePairs-->>DescrptBlockSeAtten: query_edge, key_edge, pair_mask
    DescrptBlockSeAtten->>GraphAttention: _graph_attention(gg, rr, dst, sw_e, static_nnei)
    GraphAttention->>GraphAttention: segment_softmax + segment_sum per layer
    GraphAttention-->>DescrptBlockSeAtten: updated edge features
    DescrptBlockSeAtten-->>DescrptDPA1: descriptor output
Loading

Possibly related PRs

  • deepmodeling/deepmd-kit#5583: Both PRs extend deepmd/dpmodel/descriptor/dpa1.py's graph-native lowering (uses_graph_lower/call_graph/DescrptBlockSeAtten._call_graph), with this PR expanding the attn_layer==0-only graph path established there.
  • deepmodeling/deepmd-kit#5604: Both PRs modify the same DescrptDPA1/DescrptBlockSeAtten.call_graph logic in deepmd/dpmodel/descriptor/dpa1.py.

Suggested labels: enhancement

Suggested reviewers: OutisLi, iProzd

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main change: adding graph-native se_atten attention support in dpmodel.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
source/tests/common/dpmodel/test_segment_softmax.py (1)

55-65: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Add a regression test for masked-entry-larger-than-max.

None of the mask tests here cover a masked entry whose value exceeds the unmasked max in the same segment — the scenario that triggers the NaN-propagation issue flagged in segment.py. Once that's fixed, a test like the one below would guard the regression:

def test_masked_entry_extreme_value_no_nan(self) -> None:
    logits = np.array([1.0, 1e30, 2.0])  # masked entry (idx 1) dwarfs the max
    ids = np.array([0, 0, 0], dtype=np.int64)
    mask = np.array([True, False, True])
    w = segment_softmax(logits, ids, 1, mask=mask)
    assert not np.any(np.isnan(w))
    assert w[1] == 0.0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_segment_softmax.py` around lines 55 - 65,
Add a regression test in test_segment_softmax for the
masked-entry-larger-than-max case that currently leads to NaN propagation in
segment_softmax. Extend the existing mask coverage by creating a segment where
the masked element has an extreme value above the unmasked max, then assert the
result contains no NaNs, the masked position is exactly zero, and the unmasked
weights still normalize correctly. Use the existing segment_softmax test pattern
in test_masked_entries_zero to keep the new case consistent.
deepmd/dpmodel/utils/neighbor_graph/pairs.py (1)

92-117: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

dst values are unused in the shape-static path (only its shape matters).

_pairs_shape_static derives query/key edges purely from index arithmetic assuming the center-major layout documented in the module docstring; the actual dst values are never consulted to validate that assumption. This matches the documented contract, but if a caller ever passes a dst/static_nnei combination that doesn't match the assumed layout, this silently produces wrong pairs with no diagnostic. Consider a lightweight assertion (e.g., e_tot % nn == 0) or a debug-mode check that dst is actually constant within each block, to fail fast on a layout mismatch instead of silently mis-pairing.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/utils/neighbor_graph/pairs.py` around lines 92 - 117, The
shape-static path in `_pairs_shape_static` relies on center-major block layout
but never validates that `dst` actually matches that assumption, so a mismatched
`static_nnei`/layout can silently produce wrong pairs. Add a lightweight guard
in `_pairs_shape_static` to fail fast on layout mismatches, such as verifying
`e_tot % nn == 0` and/or checking that `dst` is constant within each `nn` block
in a debug-friendly way. Keep the existing index-arithmetic logic for
`query_edge`, `key_edge`, and `pair_mask`, but ensure the contract is enforced
before returning.
deepmd/dpmodel/descriptor/dpa1.py (2)

1671-1684: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

"Bit-exact" claim needs a caveat for the default (smooth + compact) configuration.

The docstring states this is a "Bit-exact analogue of call" and the "Known limitations" section only lists tebd_input_mode and exclude_types. But per test_block_compact_graph_smooth_clean_divergence in test_dpa1_graph_attention_parity.py, when static_nnei is None (the default, compact/carry-all form) and smooth=True (also the class default), the output deliberately diverges from dense (up to ~1e-4) by design — the carry-all graph drops phantom sel-padding softmax terms that dense keeps. A reader of this docstring/API surface would not learn about this without digging into the test suite. Since smooth_type_embedding defaults to True and static_nnei defaults to None, the "bit-exact" claim is misleading for the descriptor's own default configuration.

Suggest adding a short caveat to the "Known limitations" (or a new "Notes") section referencing this divergence, mirroring what's already documented in the test docstring.

📝 Suggested docstring addition
         Notes
         -----
         Known limitations:
         - ``tebd_input_mode == "concat"`` only (strip mode lands later);
         - ``exclude_types`` is not yet supported and raises (lands in a later PR).
+        - When ``attn_layer > 0``, ``smooth_type_embedding=True`` (the class
+          default) combined with the compact/carry-all form (``static_nnei=None``,
+          also the default) intentionally diverges from the dense reference
+          (up to ~1e-4): the carry-all graph has no sel-padding slots, so it
+          drops the phantom denominator terms the dense smooth branch keeps.
+          Bit-exact parity (1e-12) only holds on the shape-static form
+          (``static_nnei`` set, as used by the dense ``call`` adapter) or when
+          ``smooth_type_embedding=False``.
         """

Also applies to: 1712-1717

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa1.py` around lines 1671 - 1684, Update the
call_graph docstring in dpa1.py to add a caveat that the “bit-exact” claim does
not hold for the default smooth + compact/carry-all configuration: when
static_nnei is None and smooth=True, the graph path can intentionally diverge
slightly from dense because it omits phantom sel-padding softmax terms. Add this
to the existing “Known limitations” or a new “Notes” section, and keep the
wording consistent with the behavior exercised by
test_block_compact_graph_smooth_clean_divergence and the related call_graph
documentation block.

1856-1932: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Extract the shared attnw_shift default. GatedAttentionLayer.call also uses 20.0, so pulling this into a shared constant would keep the dense and graph paths aligned if that default ever changes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa1.py` around lines 1856 - 1932, The hardcoded
attention shift value is duplicated in _graph_attention_one_layer and
GatedAttentionLayer.call, so pull the 20.0 default into a shared constant or
class attribute used by both paths. Update the graph attention logic to
reference that shared symbol so the dense and graph implementations stay aligned
if the default changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/utils/neighbor_graph/segment.py`:
- Around line 59-89: The masked path in segment_softmax is using raw data for
the exponent shift, which can turn masked large values into inf and then nan
after multiplying by the mask. Update segment_softmax to compute shifted from
data_for_max (the same masked-safe values used for seg_max), and keep the
existing empty/fully-masked guards so exp and denom stay finite. Check the
segment_max/segment_sum flow and the _graph_attention_one_layer caller to ensure
masked attention logits cannot leak into the denominator.

---

Nitpick comments:
In `@deepmd/dpmodel/descriptor/dpa1.py`:
- Around line 1671-1684: Update the call_graph docstring in dpa1.py to add a
caveat that the “bit-exact” claim does not hold for the default smooth +
compact/carry-all configuration: when static_nnei is None and smooth=True, the
graph path can intentionally diverge slightly from dense because it omits
phantom sel-padding softmax terms. Add this to the existing “Known limitations”
or a new “Notes” section, and keep the wording consistent with the behavior
exercised by test_block_compact_graph_smooth_clean_divergence and the related
call_graph documentation block.
- Around line 1856-1932: The hardcoded attention shift value is duplicated in
_graph_attention_one_layer and GatedAttentionLayer.call, so pull the 20.0
default into a shared constant or class attribute used by both paths. Update the
graph attention logic to reference that shared symbol so the dense and graph
implementations stay aligned if the default changes.

In `@deepmd/dpmodel/utils/neighbor_graph/pairs.py`:
- Around line 92-117: The shape-static path in `_pairs_shape_static` relies on
center-major block layout but never validates that `dst` actually matches that
assumption, so a mismatched `static_nnei`/layout can silently produce wrong
pairs. Add a lightweight guard in `_pairs_shape_static` to fail fast on layout
mismatches, such as verifying `e_tot % nn == 0` and/or checking that `dst` is
constant within each `nn` block in a debug-friendly way. Keep the existing
index-arithmetic logic for `query_edge`, `key_edge`, and `pair_mask`, but ensure
the contract is enforced before returning.

In `@source/tests/common/dpmodel/test_segment_softmax.py`:
- Around line 55-65: Add a regression test in test_segment_softmax for the
masked-entry-larger-than-max case that currently leads to NaN propagation in
segment_softmax. Extend the existing mask coverage by creating a segment where
the masked element has an extreme value above the unmasked max, then assert the
result contains no NaNs, the masked position is exactly zero, and the unmasked
weights still normalize correctly. Use the existing segment_softmax test pattern
in test_masked_entries_zero to keep the new case consistent.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 86eede99-fa33-4044-b859-5fe1eb620896

📥 Commits

Reviewing files that changed from the base of the PR and between 55d7e79 and 91784df.

📒 Files selected for processing (13)
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/utils/neighbor_graph/__init__.py
  • deepmd/dpmodel/utils/neighbor_graph/env.py
  • deepmd/dpmodel/utils/neighbor_graph/pairs.py
  • deepmd/dpmodel/utils/neighbor_graph/segment.py
  • source/tests/common/dpmodel/test_center_edge_pairs.py
  • source/tests/common/dpmodel/test_dpa1_call_graph_block.py
  • source/tests/common/dpmodel/test_dpa1_graph_attention_parity.py
  • source/tests/common/dpmodel/test_segment_softmax.py
  • source/tests/pt_expt/descriptor/test_dpa1.py
  • source/tests/pt_expt/model/test_dpa1_graph_lower.py
  • source/tests/pt_expt/model/test_linear_model.py
  • source/tests/pt_expt/utils/test_neighbor_list.py

Comment on lines +59 to +89
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.

Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
shifted = data - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Masked-softmax can silently produce NaN for an entire segment.

shifted (line 82) is computed from the raw data, not data_for_max. If a masked entry's value is larger than the per-segment max of the unmasked entries, exp(shifted) overflows to +inf for that entry; the later ex * mask (line 85) then evaluates inf * 0.0 = nan. That NaN is summed into denom (line 86) and gathered back onto every entry sharing the segment id (line 87), so the NaN contaminates the whole segment's softmax output — not just the masked entry. This is exactly the scenario the "numerically stable" masking is supposed to guard against, and it's untested (test_masked_entries_zero / test_all_masked_segment_is_zero_no_nan only use masked values smaller than the unmasked max).

Downstream, dpa1.py's _graph_attention_one_layer calls this with mask=pair_mask on raw attention logits for padding pairs, which are not bounded a priori.

🛡️ Proposed fix
     seg_max = segment_max(data_for_max, segment_ids, num_segments)
     # guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
     seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
-    shifted = data - xp.take(seg_max, segment_ids, axis=0)
+    # use data_for_max (already -inf on masked entries) so masked entries
+    # exp() to exactly 0 instead of relying on a post-hoc inf*0 multiply
+    shifted = data_for_max - xp.take(seg_max, segment_ids, axis=0)
     ex = xp.exp(shifted)
     if mask is not None:
         ex = ex * xp.astype(mask, ex.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.
Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
shifted = data - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe
def segment_softmax(
data: Array,
segment_ids: Array,
num_segments: int,
mask: Array | None = None,
) -> Array:
"""Softmax over entries sharing a segment id, numerically stable.
Mirrors the dense ``np_softmax`` max-subtraction trick with a PER-SEGMENT
max. ``mask`` (bool, per entry) removes masked entries from the softmax
entirely (zero weight AND excluded from the denominator). Empty or
fully-masked segments produce all-zero weights (no NaN).
"""
xp = array_api_compat.array_namespace(data)
if mask is not None:
# keep masked entries out of the per-segment max: send them to -inf
neg = xp.full_like(data, -xp.inf)
data_for_max = xp.where(mask, data, neg)
else:
data_for_max = data
seg_max = segment_max(data_for_max, segment_ids, num_segments)
# guard -inf (empty / fully-masked segments) so gather doesn't yield inf-inf
seg_max = xp.where(xp.isinf(seg_max), xp.zeros_like(seg_max), seg_max)
# use data_for_max (already -inf on masked entries) so masked entries
# exp() to exactly 0 instead of relying on a post-hoc inf*0 multiply
shifted = data_for_max - xp.take(seg_max, segment_ids, axis=0)
ex = xp.exp(shifted)
if mask is not None:
ex = ex * xp.astype(mask, ex.dtype)
denom = segment_sum(ex, segment_ids, num_segments)
denom_e = xp.take(denom, segment_ids, axis=0)
safe = xp.where(denom_e > 0, denom_e, xp.ones_like(denom_e))
return ex / safe
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/utils/neighbor_graph/segment.py` around lines 59 - 89, The
masked path in segment_softmax is using raw data for the exponent shift, which
can turn masked large values into inf and then nan after multiplying by the
mask. Update segment_softmax to compute shifted from data_for_max (the same
masked-safe values used for seg_max), and keep the existing empty/fully-masked
guards so exp and denom stay finite. Check the segment_max/segment_sum flow and
the _graph_attention_one_layer caller to ensure masked attention logits cannot
leak into the denominator.

@codecov

codecov Bot commented Jul 2, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 97.76119% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.17%. Comparing base (55d7e79) to head (91784df).

Files with missing lines Patch % Lines
deepmd/dpmodel/descriptor/dpa1.py 97.95% 1 Missing ⚠️
deepmd/dpmodel/utils/neighbor_graph/env.py 80.00% 1 Missing ⚠️
deepmd/dpmodel/utils/neighbor_graph/pairs.py 98.30% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5715      +/-   ##
==========================================
- Coverage   81.26%   81.17%   -0.10%     
==========================================
  Files         988      989       +1     
  Lines      110876   111007     +131     
  Branches     4234     4232       -2     
==========================================
+ Hits        90103    90106       +3     
- Misses      19247    19378     +131     
+ Partials     1526     1523       -3     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant