Skip to content

fix(jax): restore jax2tf savedmodel export#5613

Merged
wanghan-iapcm merged 4 commits into
deepmodeling:masterfrom
njzjz:fix/jax-savedmodel-jax2tf-export
Jul 1, 2026
Merged

fix(jax): restore jax2tf savedmodel export#5613
wanghan-iapcm merged 4 commits into
deepmodeling:masterfrom
njzjz:fix/jax-savedmodel-jax2tf-export

Conversation

@njzjz

@njzjz njzjz commented Jun 30, 2026

Copy link
Copy Markdown
Member

Summary

  • restore the JAX .savedmodel conversion path to use jax2tf.convert(...) instead of the TF2 SavedModel exporter
  • keep the TF2 eager exporter on the .savedmodeltf suffix
  • restore the graph-safe TensorFlow helper code used around the jax2tf-converted model body and document why these helpers should not be collapsed into TF2/ndtensorflow shims

Background

source/tests/infer/convert-models.sh documents .savedmodel as the JAX/JAX2TF output suffix and .savedmodeltf as the TF2 output suffix. After #5598, the JAX .savedmodel branch delegated to the TF2 exporter, so freshly exported .savedmodel and .savedmodeltf artifacts had the same ordinary TF op structure and no XlaCallModule nodes.

This restores the historical contract: .savedmodel is a JAX/jax2tf artifact and should contain XlaCallModule; .savedmodeltf remains the TF2 eager SavedModel artifact.

TF2 JIT note

I also checked why DP_JIT=1 on the TF2 exporter does not create XlaCallModule nodes. A minimal tf.function(jit_compile=True) SavedModel in TF 2.21 stores _XlaMustCompile: true on the FunctionDef/PartitionedCall, but the serialized graph still contains ordinary TF ops and no XlaCallModule. In contrast, a minimal jax2tf.convert(...) SavedModel serializes XlaCallModule nodes. So XlaCallModule is a marker for the jax2tf native serialization path, not for generic TF2 jit_compile=True.

Validation

  • dp convert-backend source/tests/infer/deeppot_sea.yaml /tmp/.../deeppot_sea.savedmodel, parsed saved_model.pb: 8 XlaCallModule ops
  • dp convert-backend source/tests/infer/deeppot_dpa.yaml /tmp/.../deeppot_dpa.savedmodel, parsed saved_model.pb: 8 XlaCallModule ops
  • dp convert-backend source/tests/infer/deeppot_sea.yaml /tmp/.../deeppot_sea.savedmodeltf, parsed saved_model.pb: 0 XLA op names, as expected for TF2 eager export
  • ruff format .
  • ruff check .

Please review, @wanghan-iapcm.

Summary by CodeRabbit

  • New Features

    • Enhanced JAX-based SavedModel export with additional execution endpoints and metadata queries (including neighbor-list/selection and output configuration).
    • Added in-graph coordinate transforms and improved neighbor-list/periodic ghost handling for export.
  • Bug Fixes

    • Corrected neighbor-list padding/truncation, cutoff masking, and stable type-distinguished ordering.
    • Improved handling of virtual atoms and empty-cell periodic extension cases.
  • Tests

    • Migrated neighbor-list/region tests to pure TensorFlow ops.
    • Added coverage for model call behavior and SavedModel export contents.

@coderabbitai

coderabbitai Bot commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 0c7132bf-d569-4600-a26f-60179beee5b0

📥 Commits

Reviewing files that changed from the base of the PR and between 976fe89 and 0b3809c.

📒 Files selected for processing (4)
  • source/jax2tf_tests/test_make_model.py
  • source/jax2tf_tests/test_nlist.py
  • source/jax2tf_tests/test_region.py
  • source/jax2tf_tests/test_serialization.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/jax2tf_tests/test_region.py

📝 Walkthrough

Walkthrough

The deepmd/jax/jax2tf/ modules now implement TensorFlow graph logic directly for region math, neighbor-list handling, model-call preprocessing, and SavedModel export, and the jax2tf test suite was updated to use TensorFlow ops and add export coverage.

Changes

JAX/jax2tf In-Graph Rewrite

Layer / File(s) Summary
Region geometry helpers
deepmd/jax/jax2tf/region.py
Replaces wrapper functions with direct TensorFlow implementations for coordinate transforms and face-distance calculations, and adds phys2inter and b_to_face_distance.
Neighbor list and ghost extension
deepmd/jax/jax2tf/nlist.py
Reimplements neighbor-list construction, type separation, and periodic ghost-coordinate extension as in-graph TensorFlow logic.
Neighbor-list formatting
deepmd/jax/jax2tf/format_nlist.py
Replaces the wrapper around format_nlist with a traced TensorFlow function that reshapes, pads, masks, sorts, filters by rcut, and truncates to nsel.
Model call preprocessing
deepmd/jax/jax2tf/make_model.py
Replaces TF2 delegation with a direct JAX2TF model-call pipeline that normalizes coordinates, extends ghosts, builds the neighbor list, calls call_lower, and postprocesses outputs.
SavedModel export wiring
deepmd/jax/jax2tf/serialization.py, deepmd/jax/utils/serialization.py
Implements local SavedModel export, call-lower wrappers, metadata endpoints, and the final tf.saved_model.save call; updates the utility import to the new JAX serializer.
TensorFlow-based jax2tf tests
source/jax2tf_tests/*
Updates the jax2tf test modules to use TensorFlow ops and adds coverage for model-call normalization, empty-cell neighbor extension, and SavedModel export contents.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5598: Directly touches the same deepmd/jax/jax2tf/ entrypoints and appears to be the prior shim-oriented counterpart to this rewrite.

Suggested labels

bug

Suggested reviewers

  • wanghan-iapcm
  • OutisLi
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly states the main change: restoring the JAX jax2tf SavedModel export path.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/jax2tf/region.py`:
- Around line 48-56: The b_to_face_distance helper is using the signed
determinant from tf.linalg.det, which can yield negative face distances for
left-handed cells. Update b_to_face_distance to use the absolute cell volume
before computing h2yz, h2zx, and h2xy so the returned distances are always
positive and extend_coord_with_ghosts can size ghost ranges correctly.

In `@deepmd/jax/jax2tf/serialization.py`:
- Around line 35-37: The .savedmodel deserialization path in `deserialize` is
missing `min_nbor_dist` restoration, so the metadata can be lost even when
present in the serialized payload. Update the `data["model"]` handling to also
read and pass through `min_nbor_dist` from the payload, consistent with the
`.hlo` export path, so later `get_min_nbor_dist` remains available after
loading.
- Around line 295-299: The SavedModel export in serialization.py drops the
legacy do_message_passing alias, which can break existing loaders that still
expect it. In the export setup around has_message_passing, keep
tf_model.do_message_passing pointing to the same tf.function as
has_message_passing while also retaining tf_model.has_message_passing, so both
the new and legacy endpoints are available.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a946784a-44d0-4fa5-a276-340dc28d8a73

📥 Commits

Reviewing files that changed from the base of the PR and between 73de44b and 02097f1.

📒 Files selected for processing (6)
  • deepmd/jax/jax2tf/format_nlist.py
  • deepmd/jax/jax2tf/make_model.py
  • deepmd/jax/jax2tf/nlist.py
  • deepmd/jax/jax2tf/region.py
  • deepmd/jax/jax2tf/serialization.py
  • deepmd/jax/utils/serialization.py

Comment thread deepmd/jax/jax2tf/region.py
Comment thread deepmd/jax/jax2tf/serialization.py
Comment thread deepmd/jax/jax2tf/serialization.py

@njzjz njzjz left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed the CodeRabbit review comments in ecc6d20.

Comment thread deepmd/jax/jax2tf/serialization.py
Comment thread deepmd/jax/jax2tf/nlist.py
Comment thread deepmd/jax/jax2tf/serialization.py
@njzjz njzjz requested a review from wanghan-iapcm June 30, 2026 08:25
@codecov

codecov Bot commented Jun 30, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 71.69811% with 75 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.97%. Comparing base (73de44b) to head (0b3809c).

Files with missing lines Patch % Lines
deepmd/jax/jax2tf/format_nlist.py 5.71% 33 Missing ⚠️
deepmd/jax/jax2tf/serialization.py 74.31% 28 Missing ⚠️
deepmd/jax/jax2tf/make_model.py 18.75% 13 Missing ⚠️
deepmd/jax/jax2tf/nlist.py 98.78% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5613      +/-   ##
==========================================
- Coverage   81.98%   81.97%   -0.02%     
==========================================
  Files         959      959              
  Lines      105430   105659     +229     
  Branches     4071     4073       +2     
==========================================
+ Hits        86442    86609     +167     
- Misses      17518    17579      +61     
- Partials     1470     1471       +1     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jul 1, 2026
Merged via the queue into deepmodeling:master with commit e582360 Jul 1, 2026
70 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants