Skip to content

fix(jax): reuse atomic forward outputs in force path#5605

Merged
wanghan-iapcm merged 1 commit into
deepmodeling:masterfrom
njzjz:fix/jax-forward-aux-force
Jun 30, 2026
Merged

fix(jax): reuse atomic forward outputs in force path#5605
wanghan-iapcm merged 1 commit into
deepmodeling:masterfrom
njzjz:fix/jax-forward-aux-force

Conversation

@njzjz

@njzjz njzjz commented Jun 29, 2026

Copy link
Copy Markdown
Member

Summary

  • Lazily materialize JAX atomic outputs in forward_common_atomic.
  • Use jax.jacrev(..., has_aux=True) for reducible coordinate-differentiable outputs so the force path can reuse the transformed forward's primal atomic outputs instead of running a separate atomic forward.
  • Keep the coordinate-weighted atomic virial correction as a separate transformed forward because that derivative must still flow through atomic_ret[kk].
  • Use jax.hessian(..., has_aux=True) for the hessian branch and discard the auxiliary atomic outputs there.

Requesting review from @wanghan-iapcm.

Benchmark

Command:

srun --gres=gpu:1 env PYTHONPATH=<baseline-or-this-worktree>:$PYTHONPATH \
  dp --jax train input.json --skip-neighbor-stat -o out.json

Setup:

  • Input derived from examples/water/se_e2_a/input.json
  • Absolute water data paths, numb_steps=1000, disp_freq=250, save_freq=1000000
  • JAX 0.8.0
  • GPU: NVIDIA GeForce RTX 5090
  • Baseline: 58bef11c9
  • This PR: e60455b57

Batch wall times below are the training-reported interval wall times.

Version Batch 1 Batch 250 Batch 500 Batch 750 Batch 1000 Training wall time /usr/bin/time real
Baseline 12.52 s 21.71 s 14.05 s 14.13 s 14.29 s 92.208 s 96.01 s
This PR 10.25 s 15.48 s 10.18 s 10.11 s 10.11 s 70.924 s 74.47 s

Speedup:

  • Steady-state average over the last three 250-step intervals: 14.1567 s -> 10.1333 s, about 1.40x faster, 28.4% less time.
  • End-to-end training wall time: 92.208 s -> 70.924 s, about 1.30x faster, 23.1% less time.
  • /usr/bin/time real: 96.01 s -> 74.47 s, about 1.29x faster, 22.4% less time.

The printed lcurve.out rows matched exactly between baseline and this PR at steps 1, 250, 500, 750, and 1000.

Tests

ruff format deepmd/jax/model/base_model.py
ruff check deepmd/jax/model/base_model.py
pytest source/tests/jax/test_dp_hessian_model.py -q
pytest source/tests/consistent/descriptor/test_se_e2_a.py -q -k test_jax_self_consistent

Summary by CodeRabbit

  • Bug Fixes
    • Improved coordinate-derivative handling for model outputs, which should make force and higher-order derivative calculations more reliable.
    • Reduced unnecessary recomputation during prediction, helping derived results compute more efficiently and consistently.

@njzjz njzjz requested a review from wanghan-iapcm June 29, 2026 17:13
@coderabbitai

coderabbitai Bot commented Jun 29, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 1faa0262-44f3-4e8e-a8d5-45e2ee5b4dea

📥 Commits

Reviewing files that changed from the base of the PR and between 58bef11 and e60455b.

📒 Files selected for processing (1)
  • deepmd/jax/model/base_model.py

📝 Walkthrough

Walkthrough

forward_common_atomic in deepmd/jax/model/base_model.py is refactored to replace eager atomic forward execution with a lazy get_atomic_ret() cache. An eval_output closure returning (primal, aux) is introduced; jax.jacrev and jax.hessian are updated to use has_aux=True, and extended_force construction is updated accordingly.

JAX Aux-aware Atomic Forward Refactor

Layer / File(s) Summary
Lazy caching, eval_output closure, and jacrev with has_aux
deepmd/jax/model/base_model.py
Replaces eager per-key atomic forward with get_atomic_ret() lazy cache; adds eval_output closure returning (primal, aux); vmaps jax.jacrev(..., has_aux=True) for coordinate Jacobian/force outputs.
Hessian, extended_force, and eval_ce comment
deepmd/jax/model/base_model.py
Updates hessian to jax.hessian(eval_output, has_aux=True) (vmapped); rebuilds extended_force via jnp.transpose into model_predict[kk_derv_r]; adds comment clarifying eval_ce must run its own forward rather than reuse cached primal.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly reflects the main JAX force-path change: reusing atomic forward outputs instead of recomputing them.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@codecov

codecov Bot commented Jun 29, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.37%. Comparing base (58bef11) to head (e60455b).
⚠️ Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5605   +/-   ##
=======================================
  Coverage   82.37%   82.37%           
=======================================
  Files         902      902           
  Lines      101529   101542   +13     
  Branches     4058     4057    -1     
=======================================
+ Hits        83630    83646   +16     
  Misses      16434    16434           
+ Partials     1465     1462    -3     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jun 30, 2026
Merged via the queue into deepmodeling:master with commit 63845a5 Jun 30, 2026
70 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants