Skip to content

feat(jax): add JAX-MD interface#5590

Merged
njzjz merged 6 commits into
deepmodeling:masterfrom
njzjz:feat/jax-md-interface
Jun 28, 2026
Merged

feat(jax): add JAX-MD interface#5590
njzjz merged 6 commits into
deepmodeling:masterfrom
njzjz:feat/jax-md-interface

Conversation

@njzjz

@njzjz njzjz commented Jun 25, 2026

Copy link
Copy Markdown
Member

Summary

  • add a JAX-MD adapter for DeePMD JAX models, including energy, force, neighbor-list, and as_jax_md helpers
  • add focused JAX tests for energy/force evaluation, dense neighbor conversion, and real JAX-MD neighbor lists
  • add a water JAX-MD smoke example using dpdata to read the LAMMPS data file
  • document the JAX-MD interface in the third-party integration docs

Tests

  • ruff check .
  • PYTHONPATH=. pytest source/tests/jax/test_jax_md.py -q
  • JAX-MD water smoke run on GPU

Summary by CodeRabbit

  • New Features

    • Added JAX-MD support for DeePMD JAX models, including energy, force, neighbor-list, and model-loading utilities.
    • Added a new water simulation example that runs a short JAX-MD NVE smoke test from the command line.
  • Documentation

    • Added setup and usage guidance for the JAX-MD integration.
    • Included a new third-party package entry for JAX-MD in the docs.

@coderabbitai

coderabbitai Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a DeePMD JAX-MD interface for loading .jax checkpoints, building energy and force callables, allocating dense neighbor lists, and converting dense JAX-MD neighbors into DeePMD lower-interface inputs. It also adds JAX-MD documentation, a water example script and README, and integration tests.

Changes

JAX-MD integration

Layer / File(s) Summary
Model loading and energy helpers
deepmd/jax/jax_md/__init__.py
Loads .jax checkpoints, rejects unsupported inputs, normalizes coordinates and parameters, and exports the public JAX-MD API.
Dense neighbor factory and docs
deepmd/jax/jax_md/__init__.py, doc/third-party/index.rst, doc/third-party/jaxmd.md
Builds dense neighbor lists, converts dense neighbors to lower inputs, and adds the JAX-MD third-party docs page and index entry.
Water example setup
examples/water/jax_md/README.md, examples/water/jax_md/run_jax_md.py
Adds the water example run instructions, CLI parsing, LAMMPS water loading, unit conversions, and NVE initialization.
Water example integration loop
examples/water/jax_md/run_jax_md.py
Prints initial thermo output and advances the NVE loop with per-step energy, kinetic energy, temperature, and overflow status.
JAX-MD tests
source/tests/jax/test_jax_md.py
Defines reference models and checks direct energy and force evaluation, dense-neighbor wrapping, rejection cases, and jax_md allocation.

Sequence Diagram(s)

sequenceDiagram
  participant RunJaxMd as "examples/water/jax_md/run_jax_md.py"
  participant AsJaxMd as "deepmd.jax.jax_md.as_jax_md"
  participant NeighborList as "jax_md.partition.neighbor_list"
  participant NVE as "jax_md.simulate.nve"
  participant MdStep as "md_step"

  RunJaxMd->>AsJaxMd: build neighbor_fn and potential
  AsJaxMd->>NeighborList: allocate dense neighbor list
  RunJaxMd->>NVE: construct NVE integrator
  RunJaxMd->>MdStep: advance state
  MdStep->>NeighborList: refresh neighbor indices
  MdStep->>RunJaxMd: return updated state and thermo values
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding a JAX-MD interface for the JAX backend.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
source/tests/jax/test_jax_md.py (1)

108-112: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Optional: prefer pytest.importorskip("jax_md") for the optional dependency.

Functionally equivalent to the find_spec + unittest.skipIf combination, but more idiomatic in a pytest suite and removes the need for the unittest/find_spec imports used only here.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/jax/test_jax_md.py` around lines 108 - 112, The optional jax_md
test currently uses a unittest-style skip check with find_spec, which is less
idiomatic in this pytest suite. Update test_actual_jax_md_neighbor_list to use
pytest.importorskip("jax_md") inside the test instead of the
`@unittest.skipIf`(find_spec(...)) decorator, and remove the now-unused unittest
and find_spec imports from this test module.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/jax_md.py`:
- Around line 170-175: The neighbor-list creation path in neighbor_list
currently allows callers to override the format via kwargs.setdefault, which can
defer failures until _jax_md_neighbor_to_lower_inputs sees an incompatible
shape. Update neighbor_list to explicitly validate the requested
partition.NeighborListFormat before calling partition.neighbor_list, and reject
any non-Dense format with a clear error so unsupported sparse formats fail fast.
- Around line 179-197: Reject scalar metrics in as_jax_md and
_jax_md_neighbor_to_lower_inputs by adding an early shape validation before
ghost-coordinate math. Ensure displacement_or_metric passed into
neighbor_list/_jax_md_neighbor_to_lower_inputs returns vector displacements with
the same trailing shape as coordinate differences, and raise a clear error if a
scalar metric is provided. Use the existing symbols as_jax_md, neighbor_list,
and _jax_md_neighbor_to_lower_inputs to place the guard where the displacement
function is first consumed.

---

Nitpick comments:
In `@source/tests/jax/test_jax_md.py`:
- Around line 108-112: The optional jax_md test currently uses a unittest-style
skip check with find_spec, which is less idiomatic in this pytest suite. Update
test_actual_jax_md_neighbor_list to use pytest.importorskip("jax_md") inside the
test instead of the `@unittest.skipIf`(find_spec(...)) decorator, and remove the
now-unused unittest and find_spec imports from this test module.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 93bc6f18-93c7-4fea-8770-6d47044ae7d4

📥 Commits

Reviewing files that changed from the base of the PR and between 5733301 and 1585048.

📒 Files selected for processing (6)
  • deepmd/jax/jax_md.py
  • doc/third-party/index.rst
  • doc/third-party/jaxmd.md
  • examples/water/jax_md/README.md
  • examples/water/jax_md/run_jax_md.py
  • source/tests/jax/test_jax_md.py

Comment thread deepmd/jax/jax_md.py Outdated
Comment thread deepmd/jax/jax_md/__init__.py
@codecov

codecov Bot commented Jun 25, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 62.06897% with 66 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.32%. Comparing base (5733301) to head (4e0b1df).
⚠️ Report is 10 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/jax/jax_md/__init__.py 62.06% 66 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5590      +/-   ##
==========================================
+ Coverage   82.27%   82.32%   +0.04%     
==========================================
  Files         887      897      +10     
  Lines      100331   101126     +795     
  Branches     4060     4060              
==========================================
+ Hits        82550    83249     +699     
- Misses      16320    16412      +92     
- Partials     1461     1465       +4     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@OutisLi OutisLi left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps jax_md.py should not be placed at the root folder of jax backend

@njzjz

njzjz commented Jun 26, 2026

Copy link
Copy Markdown
Member Author

Perhaps jax_md.py should not be placed at the root folder of jax backend

@OutisLi which folder do you suggest?

Comment thread deepmd/jax/jax_md.py Outdated
@OutisLi

OutisLi commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

Perhaps jax_md.py should not be placed at the root folder of jax backend

@OutisLi which folder do you suggest?

Perhaps a third party folder or inference folder

@njzjz njzjz requested review from OutisLi and wanghan-iapcm June 27, 2026 12:50

@iProzd iProzd left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see a blocking issue.
One small documentation fix: load_model() now explicitly rejects .hlo with NotImplementedError, but doc/third-party/jaxmd.md still says .hlo is accepted. Please update that section to say that only .jax checkpoints (or an already constructed model object) are currently supported by the JAX-MD adapter, and that .hlo is not supported because the adapter needs JAX-differentiable energy functions.
LGTM after that doc tweak.

@njzjz njzjz enabled auto-merge June 28, 2026 14:30
@njzjz njzjz added this pull request to the merge queue Jun 28, 2026
Merged via the queue into deepmodeling:master with commit ac8e430 Jun 28, 2026
70 checks passed
@njzjz njzjz deleted the feat/jax-md-interface branch June 28, 2026 23:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants