feat(jax): add JAX-MD interface#5590
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a DeePMD JAX-MD interface for loading ChangesJAX-MD integration
Sequence Diagram(s)sequenceDiagram
participant RunJaxMd as "examples/water/jax_md/run_jax_md.py"
participant AsJaxMd as "deepmd.jax.jax_md.as_jax_md"
participant NeighborList as "jax_md.partition.neighbor_list"
participant NVE as "jax_md.simulate.nve"
participant MdStep as "md_step"
RunJaxMd->>AsJaxMd: build neighbor_fn and potential
AsJaxMd->>NeighborList: allocate dense neighbor list
RunJaxMd->>NVE: construct NVE integrator
RunJaxMd->>MdStep: advance state
MdStep->>NeighborList: refresh neighbor indices
MdStep->>RunJaxMd: return updated state and thermo values
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
source/tests/jax/test_jax_md.py (1)
108-112: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueOptional: prefer
pytest.importorskip("jax_md")for the optional dependency.Functionally equivalent to the
find_spec+unittest.skipIfcombination, but more idiomatic in a pytest suite and removes the need for theunittest/find_specimports used only here.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_jax_md.py` around lines 108 - 112, The optional jax_md test currently uses a unittest-style skip check with find_spec, which is less idiomatic in this pytest suite. Update test_actual_jax_md_neighbor_list to use pytest.importorskip("jax_md") inside the test instead of the `@unittest.skipIf`(find_spec(...)) decorator, and remove the now-unused unittest and find_spec imports from this test module.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/jax_md.py`:
- Around line 170-175: The neighbor-list creation path in neighbor_list
currently allows callers to override the format via kwargs.setdefault, which can
defer failures until _jax_md_neighbor_to_lower_inputs sees an incompatible
shape. Update neighbor_list to explicitly validate the requested
partition.NeighborListFormat before calling partition.neighbor_list, and reject
any non-Dense format with a clear error so unsupported sparse formats fail fast.
- Around line 179-197: Reject scalar metrics in as_jax_md and
_jax_md_neighbor_to_lower_inputs by adding an early shape validation before
ghost-coordinate math. Ensure displacement_or_metric passed into
neighbor_list/_jax_md_neighbor_to_lower_inputs returns vector displacements with
the same trailing shape as coordinate differences, and raise a clear error if a
scalar metric is provided. Use the existing symbols as_jax_md, neighbor_list,
and _jax_md_neighbor_to_lower_inputs to place the guard where the displacement
function is first consumed.
---
Nitpick comments:
In `@source/tests/jax/test_jax_md.py`:
- Around line 108-112: The optional jax_md test currently uses a unittest-style
skip check with find_spec, which is less idiomatic in this pytest suite. Update
test_actual_jax_md_neighbor_list to use pytest.importorskip("jax_md") inside the
test instead of the `@unittest.skipIf`(find_spec(...)) decorator, and remove the
now-unused unittest and find_spec imports from this test module.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 93bc6f18-93c7-4fea-8770-6d47044ae7d4
📒 Files selected for processing (6)
deepmd/jax/jax_md.pydoc/third-party/index.rstdoc/third-party/jaxmd.mdexamples/water/jax_md/README.mdexamples/water/jax_md/run_jax_md.pysource/tests/jax/test_jax_md.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5590 +/- ##
==========================================
+ Coverage 82.27% 82.32% +0.04%
==========================================
Files 887 897 +10
Lines 100331 101126 +795
Branches 4060 4060
==========================================
+ Hits 82550 83249 +699
- Misses 16320 16412 +92
- Partials 1461 1465 +4 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
OutisLi
left a comment
There was a problem hiding this comment.
Perhaps jax_md.py should not be placed at the root folder of jax backend
@OutisLi which folder do you suggest? |
Perhaps a third party folder or inference folder |
iProzd
left a comment
There was a problem hiding this comment.
I do not see a blocking issue.
One small documentation fix: load_model() now explicitly rejects .hlo with NotImplementedError, but doc/third-party/jaxmd.md still says .hlo is accepted. Please update that section to say that only .jax checkpoints (or an already constructed model object) are currently supported by the JAX-MD adapter, and that .hlo is not supported because the adapter needs JAX-differentiable energy functions.
LGTM after that doc tweak.
Summary
as_jax_mdhelpersTests
ruff check .PYTHONPATH=. pytest source/tests/jax/test_jax_md.py -qSummary by CodeRabbit
New Features
Documentation