fix(jax): reuse atomic forward outputs in force path#5605
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthrough
JAX Aux-aware Atomic Forward Refactor
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5605 +/- ##
=======================================
Coverage 82.37% 82.37%
=======================================
Files 902 902
Lines 101529 101542 +13
Branches 4058 4057 -1
=======================================
+ Hits 83630 83646 +16
Misses 16434 16434
+ Partials 1465 1462 -3 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
forward_common_atomic.jax.jacrev(..., has_aux=True)for reducible coordinate-differentiable outputs so the force path can reuse the transformed forward's primal atomic outputs instead of running a separate atomic forward.atomic_ret[kk].jax.hessian(..., has_aux=True)for the hessian branch and discard the auxiliary atomic outputs there.Requesting review from @wanghan-iapcm.
Benchmark
Command:
Setup:
examples/water/se_e2_a/input.jsonnumb_steps=1000,disp_freq=250,save_freq=100000058bef11c9e60455b57Batchwall times below are the training-reported interval wall times./usr/bin/time realSpeedup:
14.1567 s -> 10.1333 s, about1.40xfaster,28.4%less time.92.208 s -> 70.924 s, about1.30xfaster,23.1%less time./usr/bin/time real:96.01 s -> 74.47 s, about1.29xfaster,22.4%less time.The printed
lcurve.outrows matched exactly between baseline and this PR at steps 1, 250, 500, 750, and 1000.Tests
Summary by CodeRabbit