feat(dpa4): so3_readout across pt + dpmodel/pt_expt backends#5561
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bdefabe7a2
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds a Changesso3_readout feature for DPA4/SeZM
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/descriptor/sezm.py`:
- Around line 1236-1243: When `so3_readout != "none"`, the `ffn_in` variable is
set to the full `x` tensor without truncation, but `x` may retain its initial
dimension (`self.node_ebed_dims[0]`) instead of the final expected dimension
(`self.node_ebed_dims[-1]`) when `_forward_blocks` is skipped. This causes a
shape mismatch when feeding to `output_ffn`. In the else branch of the
conditional that sets `ffn_in`, truncate `x` to match the final node embedding
dimension before casting to the compute dtype, similar to the slicing applied in
the `if self.so3_readout == "none"` branch, to ensure the tensor shape matches
what `output_ffn` expects.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: dd93cd81-2191-41e0-a1ea-4c93376cdf20
📒 Files selected for processing (6)
deepmd/dpmodel/descriptor/dpa4.pydeepmd/pt/model/descriptor/sezm.pydeepmd/utils/argcheck.pyexamples/water/dpa4/input.jsonsource/tests/consistent/descriptor/test_dpa4.pysource/tests/pt/model/test_dpa4_dpmodel_parity.py
…e path) AI-review (CodeRabbit/codex) finding on deepmodeling#5561: with so3_readout=glu/mlp and a shrinking l_schedule, the empty-edge path skips _forward_blocks, leaving x at the initial node degree (node_ebed_dims[0]); the full x was fed to output_ffn built for node_ebed_dims[-1] -> SO3Linear einsum shape mismatch on isolated atoms. Truncate the readout input to node_ebed_dims[-1] (no-op once blocks ran). - pt sezm.py: slice x to node_ebed_dims[-1] in both readout sites (forward, forward_with_edges). - dpmodel dpa4.py: same truncation for symmetry/robustness (no-op there since padded-edge blocks always shrink x). - regression test: so3_readout glu/mlp + shrinking schedule + all-isolated nlist (proven to fail pre-fix with the einsum size 4 vs 9 mismatch).
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5561 +/- ##
=======================================
Coverage 82.15% 82.15%
=======================================
Files 898 898
Lines 103306 103317 +11
Branches 4410 4412 +2
=======================================
+ Hits 84867 84885 +18
+ Misses 17065 17058 -7
Partials 1374 1374 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
What
Adds the DPA4/SeZM
so3_readoutoption ("none"/"glu"/"mlp") across all backends (pt + dpmodel + pt_expt), making it cross-backend consistent.Builds on #5556 (@OutisLi): its pt
so3_readoutcommit (refactor(dpa4): output ffn) is included here with original authorship preserved; this PR adds the missing dpmodel counterpart so the shared DPA4 serialize format round-trips across backends. (The unrelatednv_nlist/compiler-check fixes from #5556 are intentionally left to #5556.)Why
so3_readoutis implemented by configuring the final output FFN —"glu"/"mlp"turn on the SO(3)-grid FFN (ffn_so3_grid,grid_mlp). On its own (pt-only, as in #5556) it breaks DPA4 cross-backend consistency: ptserialize()emitsso3_readoutbut dpmodelDescrptDPA4couldn't round-trip it →source/tests/consistent/descriptor/test_dpa4.py::...::test_pt_consistent_with_reffailed on every Test Python shard.This is now feasible and small because #5555 already ported
ffn_so3_grid+ the SO(3)-grid machinery (SO3GridNet/GridMLP/GridProduct) into the dpmodelEquivariantFFN. So the dpmodelso3_readoutis just: accept the param, configureoutput_ffnexactly like pt, wire the readout (l=0 slice for"none"; full(N,D,1,C)fold for"glu"/"mlp"then slice l=0), and serialize the key. pt_expt auto-wraps.Changes
so3_readoutinDescrptSeZM+ argcheck +examples/water/dpa4/input.json.descriptor/dpa4.py:so3_readoutparam + validation;output_ffnconfigured (lmax=node_l_schedule[-1],kmax=min(kmax, readout_lmax),ffn_so3_grid,grid_mlp,grid_branch=0); readout forward mirrors pt; serialize the key.Validation
test_dpa4.pycross-backend consistency rows forso3_readout ∈ {none, glu, mlp}(pt vs dpmodel vs pt_expt, mixed_types) — green;test_pt_consistent_with_refnow passes.glu+mlp) — ~7e-15 abs (gate 1e-10), proving serialize interop.Notes
so3_readoutno longer "(Supported Backend: PyTorch)" — now multi-backend.Summary by CodeRabbit
Release Notes
New Features
so3_readoutoption to the DPA4 and SeZM descriptors (modes:"none","glu","mlp"), controlling how the final SO(3) readout is computed.so3_readout: "mlp".Tests