fix(jax): restore jax2tf savedmodel export#5613
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughThe ChangesJAX/jax2tf In-Graph Rewrite
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/jax2tf/region.py`:
- Around line 48-56: The b_to_face_distance helper is using the signed
determinant from tf.linalg.det, which can yield negative face distances for
left-handed cells. Update b_to_face_distance to use the absolute cell volume
before computing h2yz, h2zx, and h2xy so the returned distances are always
positive and extend_coord_with_ghosts can size ghost ranges correctly.
In `@deepmd/jax/jax2tf/serialization.py`:
- Around line 35-37: The .savedmodel deserialization path in `deserialize` is
missing `min_nbor_dist` restoration, so the metadata can be lost even when
present in the serialized payload. Update the `data["model"]` handling to also
read and pass through `min_nbor_dist` from the payload, consistent with the
`.hlo` export path, so later `get_min_nbor_dist` remains available after
loading.
- Around line 295-299: The SavedModel export in serialization.py drops the
legacy do_message_passing alias, which can break existing loaders that still
expect it. In the export setup around has_message_passing, keep
tf_model.do_message_passing pointing to the same tf.function as
has_message_passing while also retaining tf_model.has_message_passing, so both
the new and legacy endpoints are available.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a946784a-44d0-4fa5-a276-340dc28d8a73
📒 Files selected for processing (6)
deepmd/jax/jax2tf/format_nlist.pydeepmd/jax/jax2tf/make_model.pydeepmd/jax/jax2tf/nlist.pydeepmd/jax/jax2tf/region.pydeepmd/jax/jax2tf/serialization.pydeepmd/jax/utils/serialization.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5613 +/- ##
==========================================
- Coverage 81.98% 81.97% -0.02%
==========================================
Files 959 959
Lines 105430 105659 +229
Branches 4071 4073 +2
==========================================
+ Hits 86442 86609 +167
- Misses 17518 17579 +61
- Partials 1470 1471 +1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
.savedmodelconversion path to usejax2tf.convert(...)instead of the TF2 SavedModel exporter.savedmodeltfsuffixBackground
source/tests/infer/convert-models.shdocuments.savedmodelas the JAX/JAX2TF output suffix and.savedmodeltfas the TF2 output suffix. After #5598, the JAX.savedmodelbranch delegated to the TF2 exporter, so freshly exported.savedmodeland.savedmodeltfartifacts had the same ordinary TF op structure and noXlaCallModulenodes.This restores the historical contract:
.savedmodelis a JAX/jax2tf artifact and should containXlaCallModule;.savedmodeltfremains the TF2 eager SavedModel artifact.TF2 JIT note
I also checked why
DP_JIT=1on the TF2 exporter does not createXlaCallModulenodes. A minimaltf.function(jit_compile=True)SavedModel in TF 2.21 stores_XlaMustCompile: trueon the FunctionDef/PartitionedCall, but the serialized graph still contains ordinary TF ops and noXlaCallModule. In contrast, a minimaljax2tf.convert(...)SavedModel serializesXlaCallModulenodes. SoXlaCallModuleis a marker for the jax2tf native serialization path, not for generic TF2jit_compile=True.Validation
dp convert-backend source/tests/infer/deeppot_sea.yaml /tmp/.../deeppot_sea.savedmodel, parsedsaved_model.pb: 8XlaCallModuleopsdp convert-backend source/tests/infer/deeppot_dpa.yaml /tmp/.../deeppot_dpa.savedmodel, parsedsaved_model.pb: 8XlaCallModuleopsdp convert-backend source/tests/infer/deeppot_sea.yaml /tmp/.../deeppot_sea.savedmodeltf, parsedsaved_model.pb: 0 XLA op names, as expected for TF2 eager exportruff format .ruff check .Please review, @wanghan-iapcm.
Summary by CodeRabbit
New Features
Bug Fixes
Tests