Skip to content

fix(jax): stabilize repflow dynamic selection export#5533

Merged
njzjz merged 4 commits into
deepmodeling:masterfrom
njzjz:fix/jax-repflow-static-dynamic-sel
Jun 19, 2026
Merged

fix(jax): stabilize repflow dynamic selection export#5533
njzjz merged 4 commits into
deepmodeling:masterfrom
njzjz:fix/jax-repflow-static-dynamic-sel

Conversation

@njzjz

@njzjz njzjz commented Jun 14, 2026

Copy link
Copy Markdown
Member

Summary

  • add an internal fixed-capacity dynamic-selection layout for repflows so JAX/jax2tf export avoids runtime-sized edge/angle tensors
  • skip unnecessary bincount in sum-only aggregate calls with a known owner count
  • add regression coverage comparing compact and static dynamic selection outputs

Validation

  • ruff check .
  • ruff format .
  • pytest source/tests/universal/dpmodel/descriptor/test_descriptor.py::TestDPA3StaticDynamicSelDP::test_static_dynamic_sel_matches_packed_dynamic_sel -q
  • dp convert-backend DPA-3.2-5M.pth DPA-3.2-5M.savedmodel

Summary by CodeRabbit

  • New Features
    • Added an export-friendly “static dynamic” execution mode for DPA3 repflow neighbor/edge/angle selection (fixed-capacity layout).
  • Bug Fixes
    • Improved static-dynamic indexing and edge count handling to keep outputs consistent, including correct padded-slot behavior and layout-mode preservation on restore.
    • Updated export-time warning when static-capacity materialization is triggered.
  • Documentation
    • Clarified DPA3 repflow/use_dynamic_sel angle-selection constraints for JAX export and memory scaling.
  • Chores
    • Refined aggregate to compute bin counts only when required.
  • Tests
    • Added tests verifying static-dynamic matches packed dynamic results and validating aggregate with explicit num_owner.

@coderabbitai

coderabbitai Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8071aa5f-a4f5-461f-ac14-4d1065a15c13

📥 Commits

Reviewing files that changed from the base of the PR and between 6933bca and 9fd7aa9.

📒 Files selected for processing (3)
  • deepmd/utils/argcheck.py
  • source/tests/common/dpmodel/test_network.py
  • source/tests/universal/dpmodel/descriptor/test_descriptor.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/universal/dpmodel/descriptor/test_descriptor.py

📝 Walkthrough

Walkthrough

Adds a backend-internal _use_static_dynamic_sel flag to DescrptBlockRepflows and RepFlowLayer, enabling a fixed-capacity (padded) execution path for edges and angles as an alternative to compact boolean-mask compaction. A new _get_static_graph_index helper builds the padded index tensors. The JAX subclasses override the flag to True. The aggregate utility is refactored to skip bin_count computation for pure sum reductions. Documentation is updated with JAX export memory constraints. An equivalence test verifies both paths produce matching outputs.

Changes

Static dynamic selection for RepFlows

Layer / File(s) Summary
Flag definition and instance snapshot
deepmd/dpmodel/descriptor/repflows.py
Defines _use_static_dynamic_sel as a class-level bool on both DescrptBlockRepflows and RepFlowLayer. Adds import warnings to support runtime alerts. In DescrptBlockRepflows.__init__, snapshots the class default to the instance and conditionally issues a warning when use_dynamic_sel and static mode are both enabled. RepFlowLayer.__init__ initializes the flag from the class default for the descriptor to override later.
Flag propagation to owned layers and deserialization
deepmd/dpmodel/descriptor/repflows.py
During DescrptBlockRepflows.__init__, propagates the snapped _use_static_dynamic_sel value to each owned RepFlowLayer. After deserialize(), re-copies the restored descriptor's flag to all restored layers to maintain consistent layout mode across the deserialized graph.
_get_static_graph_index helper and call() branching
deepmd/dpmodel/descriptor/repflows.py
Implements _get_static_graph_index returning fixed-capacity edge_index (2×n_edges) and angle_index (3×n_angles) with padded slots. Extends DescrptBlockRepflows.call() to branch on the flag: static path uses _get_static_graph_index and reshapes to flattened fixed capacities with full (j,k) angle gating; compact path keeps get_graph_index with boolean-mask compaction. Updates RepFlowLayer.call() to read n_edge from h2.shape[0] in static mode versus a masked sum in compact mode.
aggregate: conditional bin_count and validation tests
deepmd/dpmodel/utils/network.py, source/tests/common/dpmodel/test_network.py
Refactors aggregate to compute bin_count only when num_owner is absent or averaging is requested; skips bin_count for pure sum reductions by setting it to None. Allocates output directly with (num_owner, feature_dim) and asserts bin_count is not None before the averaging divide. Adds TestAggregate test class to validate both average=True and average=False modes with explicit num_owner parameter.
JAX backend override
deepmd/jax/descriptor/repflows.py
Sets _use_static_dynamic_sel = True on both DescrptBlockRepflows and RepFlowLayer JAX subclasses to activate fixed-capacity layout for JAX/jax2tf export.
Documentation updates for JAX export constraints
deepmd/utils/argcheck.py
Updates DPA3 repflow argument documentation to note JAX export requires static shapes. Extended doc_a_sel and doc_use_dynamic_sel to warn that angle-pair work arrays scale with nf × nloc × a_sel², advising users to keep a_sel minimal for JAX memory efficiency.
Equivalence test for static vs packed dynamic selection
source/tests/universal/dpmodel/descriptor/test_descriptor.py
Adds TestDPA3StaticDynamicSelDP with imports for NumPy, DescrptBlockRepflows, and TestCaseSingleFrameWithNlist. Includes a _make_dpa3 helper that toggles _use_static_dynamic_sel at construction time, and a test that asserts packed dynamic and static dynamic outputs match by masking padded slots via nlist != -1.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5355: Both PRs modify deepmd/dpmodel/descriptor/repflows.py in the same core classes (DescrptBlockRepflows/RepFlowLayer) and extend call/forward control-flow with new execution-mode flags, so they overlap at the graph-execution logic level.
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 30.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically summarizes the main change: stabilizing repflow dynamic selection export for JAX, which is the primary objective of this PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/universal/dpmodel/descriptor/test_descriptor.py (1)

915-933: ⚡ Quick win

Cover the use_loc_mapping=False static-index branch too.

_get_static_graph_index() changes its indexing stride when use_loc_mapping is disabled, but this helper always constructs the default mapped configuration. Extending this test to run both values would cover the second branch of the new static-dynamic layout logic that JAX now enables by default.

♻️ Suggested test expansion
-    def _make_dpa3(self, use_static_dynamic_sel: bool) -> DescrptDPA3:
+    def _make_dpa3(
+        self,
+        use_static_dynamic_sel: bool,
+        *,
+        use_loc_mapping: bool,
+    ) -> DescrptDPA3:
         # The switch is intentionally class-level and internal, so tests toggle
         # it only around construction and then restore the previous backend mode.
         old_use_static_dynamic_sel = DescrptBlockRepflows._use_static_dynamic_sel
         DescrptBlockRepflows._use_static_dynamic_sel = use_static_dynamic_sel
         try:
             return DescrptDPA3(
                 **DescriptorParamDPA3(
                     self.nt,
                     self.rcut,
                     self.rcut_smth,
                     self.sel,
                     ["O", "H"],
                     smooth_edge_update=True,
                     use_dynamic_sel=True,
+                    use_loc_mapping=use_loc_mapping,
                 )
             )
         finally:
             DescrptBlockRepflows._use_static_dynamic_sel = old_use_static_dynamic_sel

     def test_static_dynamic_sel_matches_packed_dynamic_sel(self) -> None:
-        packed = self._make_dpa3(False)
-        static = self._make_dpa3(True)
+        for use_loc_mapping in (True, False):
+            packed = self._make_dpa3(False, use_loc_mapping=use_loc_mapping)
+            static = self._make_dpa3(True, use_loc_mapping=use_loc_mapping)
 
-        packed_out = packed(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
-        static_out = static(
-            self.coord_ext,
-            self.atype_ext,
-            self.nlist,
-            mapping=self.mapping,
-        )
+            packed_out = packed(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )
+            static_out = static(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+                mapping=self.mapping,
+            )

-        np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
-        np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)
+            np.testing.assert_allclose(packed_out[0], static_out[0], atol=self.atol)
+            np.testing.assert_allclose(packed_out[1], static_out[1], atol=self.atol)

-        valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
-        assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
-        np.testing.assert_allclose(
-            packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
-        )
-        np.testing.assert_allclose(
-            packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
-        )
+            valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
+            assert static_out[2].shape[0] == self.nf * self.nloc * sum(self.sel)
+            np.testing.assert_allclose(
+                packed_out[2], static_out[2][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[3], static_out[3][valid_edge_mask], atol=self.atol
+            )
+            np.testing.assert_allclose(
+                packed_out[4], static_out[4][valid_edge_mask], atol=self.atol
+            )

Also applies to: 935-968

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py` around lines
915 - 933, The _make_dpa3 helper method currently only constructs the default
mapped configuration, but _get_static_graph_index() has different indexing
behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method to
accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/universal/dpmodel/descriptor/test_descriptor.py`:
- Around line 915-933: The _make_dpa3 helper method currently only constructs
the default mapped configuration, but _get_static_graph_index() has different
indexing behavior when use_loc_mapping is disabled. Extend the _make_dpa3 method
to accept a parameter for use_loc_mapping (similar to how it accepts
use_static_dynamic_sel) and ensure it applies this parameter when constructing
DescriptorParamDPA3. Then update all test methods that use _make_dpa3 (including
those at lines 935-968) to run test assertions with both use_loc_mapping=True
and use_loc_mapping=False so that both branches of the static-dynamic layout
logic are covered.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b5b5ca58-1ff0-4fbd-beea-b662619d0e7b

📥 Commits

Reviewing files that changed from the base of the PR and between c0b0319 and b338bb1.

📒 Files selected for processing (4)
  • deepmd/dpmodel/descriptor/repflows.py
  • deepmd/dpmodel/utils/network.py
  • deepmd/jax/descriptor/repflows.py
  • source/tests/universal/dpmodel/descriptor/test_descriptor.py

@codecov

codecov Bot commented Jun 14, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 98.27586% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 82.17%. Comparing base (c0b0319) to head (9fd7aa9).
⚠️ Report is 23 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/dpmodel/utils/network.py 90.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5533      +/-   ##
==========================================
- Coverage   82.18%   82.17%   -0.01%     
==========================================
  Files         890      896       +6     
  Lines      101357   102688    +1331     
  Branches     4240     4343     +103     
==========================================
+ Hits        83301    84385    +1084     
- Misses      16754    16966     +212     
- Partials     1302     1337      +35     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz requested review from iProzd and wanghan-iapcm June 14, 2026 17:32
@njzjz njzjz added the Test CUDA Trigger test CUDA workflow label Jun 16, 2026
@github-actions github-actions Bot removed the Test CUDA Trigger test CUDA workflow label Jun 16, 2026
Comment thread deepmd/jax/descriptor/repflows.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/descriptor/repflows.py`:
- Around line 280-285: In the warnings.warn() call that checks
self.use_dynamic_sel and self._use_static_dynamic_sel, add the stacklevel=2
parameter to the warnings.warn() function invocation. This parameter should be
added as a keyword argument after the warning message string to ensure the
warning points to the caller's instantiation site rather than the internal line
where the warning is issued.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 75695fc0-9c46-4dbf-ad41-383d69efbd80

📥 Commits

Reviewing files that changed from the base of the PR and between b338bb1 and 31516c6.

📒 Files selected for processing (1)
  • deepmd/dpmodel/descriptor/repflows.py

Comment thread deepmd/dpmodel/descriptor/repflows.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <njzjz@qq.com>

@iProzd iProzd left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two non-blocking suggestions:

  1. The a_sel**2 angle materialization is the main caveat. The warnings.warn helps, but it'd be good to also document in the export docs that JAX export materializes nf * nloc * a_sel**2 angle slots, so a_sel should be kept modest for exportable models.
  2. Patch coverage on network.py is low (~30%) — the new aggregate branches (the num_owner concat path and the average=True assert) aren't exercised. A small unit test targeting aggregate directly would prevent regressions.

LGTM otherwise.

@njzjz njzjz enabled auto-merge June 19, 2026 18:58
@njzjz njzjz added this pull request to the merge queue Jun 19, 2026
Merged via the queue into deepmodeling:master with commit 5a0d505 Jun 19, 2026
70 checks passed
@njzjz njzjz deleted the fix/jax-repflow-static-dynamic-sel branch June 19, 2026 23:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants