fix(tf): add InvalidArgument compatibility wrapper #5600
📝 Walkthrough
Introduces a
Changes: TF compatibility error helper and op migration
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
source/op/tf/prod_force_se_a_mask.cc (1)
39-68: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Complete the shape guards before flattening these tensors.

This block still leaves `mask_tensor`, `net_deriv_tensor`, and `nlist_tensor` only partially validated. Lines 84-145 index them as fixed `[nframes, total_atom_num, ...]` layouts, so short mask/frame dimensions, too few `net_deriv` columns, or a narrow `nlist` will read past the allocation instead of failing fast.

Suggested fix

       OP_REQUIRES(
           context, (nframes == in_deriv_tensor.shape().dim_size(0)),
           deepmd::tf_compat::InvalidArgument("number of samples should match"));
       OP_REQUIRES(
           context, (nframes == nlist_tensor.shape().dim_size(0)),
           deepmd::tf_compat::InvalidArgument("number of samples should match"));
    +  OP_REQUIRES(
    +      context, (nframes == mask_tensor.shape().dim_size(0)),
    +      deepmd::tf_compat::InvalidArgument("number of samples should match"));
    +  OP_REQUIRES(
    +      context,
    +      (static_cast<int64_t>(nloc) * ndescrpt ==
    +       net_deriv_tensor.shape().dim_size(1)),
    +      deepmd::tf_compat::InvalidArgument(
    +          "number of descriptors should match"));
       OP_REQUIRES(context,
                   (static_cast<int64_t>(nloc) * ndescrpt * 3 ==
                    in_deriv_tensor.shape().dim_size(1)),
                   deepmd::tf_compat::InvalidArgument(
                       "number of descriptors should match"));
    +  OP_REQUIRES(
    +      context, (total_atom_num == mask_tensor.shape().dim_size(1)),
    +      deepmd::tf_compat::InvalidArgument("number of atoms should match"));
    +  OP_REQUIRES(
    +      context,
    +      (static_cast<int64_t>(total_atom_num) * total_atom_num ==
    +       nlist_tensor.shape().dim_size(1)),
    +      deepmd::tf_compat::InvalidArgument("number of neighbors should match"));

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In `@source/op/tf/prod_force_se_a_mask.cc` around lines 39 - 68, The shape validation in this op is incomplete, so add the missing guards in the same setup block before any flattening or indexing in the main compute path. Use the existing OP_REQUIRES checks around net_deriv_tensor, in_deriv_tensor, mask_tensor, and nlist_tensor to verify the expected nframes and per-frame widths for the fixed [nframes, total_atom_num, ...] layout, including mask_tensor rows/columns, net_deriv_tensor column count, and nlist_tensor width, so the later indexing logic cannot read past allocated tensors.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@source/op/tf/descrpt_se_a_mask.cc`:
- Around line 91-96: The `DescrptSeAMask` input validation is missing a size
check for `type_tensor`’s second dimension, even though `type(kk, ii)` is
accessed for every atom in the loop. Add an `OP_REQUIRES` in the same validation
block that compares `type_tensor.shape().dim_size(1)` against `total_atom_num`,
matching the existing checks for `coord_tensor` and `mask_matrix_tensor`, so
malformed inputs are rejected before `compute_input_stats`/the atom loop reads
past the buffer.
In `@source/op/tf/pair_tab.cc`:
- Around line 55-74: `PairTabOp::Compute` only checks the rank of
`table_info_tensor` and `table_data_tensor`, but later indexing in the
`table_info(...)` reads and the `table_data_tensor` access can still go out of
bounds for short 1-D inputs. Add explicit `dim_size(0)` validation for both
tensors before any indexing, using `OP_REQUIRES` with
`deepmd::tf_compat::InvalidArgument`, so invalid lengths fail gracefully instead
of relying on the debug-only assert.
In `@source/op/tf/pairwise.cc`:
- Around line 47-59: The validation in the pairwise op is incomplete:
`idxs_tensor` and `natoms_tensor` only have rank checks before `natoms(0)`,
`natoms(1)`, the `nall`-based loop, and the natoms output writes are used. Add
explicit size/shape checks in the `Pairwise` op path before any indexing or
iteration so malformed inputs fail through `OP_REQUIRES` with `InvalidArgument`;
use the existing `idxs_tensor`, `natoms_tensor`, `nframes`, and `natoms`
handling in this block and the later natoms output section as the locations to
tighten validation.
- Around line 240-244: The shape checks in the pairwise op are incomplete:
`sub_natoms_tensor` and `sub_forward_map_tensor` are used later without
validating their ranks and minimum sizes. Add explicit `OP_REQUIRES` checks in
the same validation block before `sub_natoms_tensor.vec<int>()` and before any
`sub_forward_map(ii, jj)` access, ensuring `sub_natoms_tensor` is 1-D with at
least 2 elements and `sub_forward_map_tensor` is 2-D with dimensions large
enough for the later indexing. Keep the existing `natoms_tensor` validation, and
make the new checks use `InvalidArgument` so malformed inputs fail cleanly
instead of hitting unchecked indexing in the pairwise op.
In `@source/op/tf/prod_virial_grad.cc`:
- Around line 96-98: The virial-grad shape check in OP_REQUIRES enforces a width
of 9, but the InvalidArgument text is still describing the wrong shape. Update
the error message in the prod_virial_grad.cc validation path to match the actual
expected tensor shape checked in the virial-grad logic, using the existing
OP_REQUIRES and deepmd::tf_compat::InvalidArgument call so the failure message
is accurate and consistent.
In `@source/op/tf/tabulate_multi_device.cc`:
- Around line 344-347: Validate last_layer_size before any GPU launch in
tabulate_multi_device.cc: the four GPU branches around
deepmd::tabulate_fusion_*_grad_grad_gpu currently call the kernel and only
afterward run OP_REQUIRES on last_layer_size <= 1024. Move the guard ahead of
each kernel invocation so invalid widths are rejected before launch, using the
existing OP_REQUIRES/deepmd::tf_compat::InvalidArgument check in each branch.
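
The ordering issue in the last comment (validate the width, then launch) can be sketched without TensorFlow or a GPU. This is only an illustration under stated assumptions: `LaunchResult`, `fake_kernel`, and `guarded_launch` are hypothetical names, the kernel is a CPU stand-in, and the negative-width check is an extra precaution that the review comment does not mention. Only the 1024 limit comes from the comment itself.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type. The real op reports failures through OP_REQUIRES.
struct LaunchResult {
  bool ok;
  bool kernel_ran;
  std::string error;
};

// Limit quoted in the review comment (last_layer_size <= 1024).
constexpr int kMaxLastLayerSize = 1024;

// CPU stand-in for a GPU kernel launch: it only sizes and zeroes the output.
inline void fake_kernel(std::vector<double>& out, int last_layer_size) {
  out.assign(static_cast<std::size_t>(last_layer_size), 0.0);
}

inline LaunchResult guarded_launch(std::vector<double>& out,
                                   int last_layer_size) {
  // Guard first, so an invalid width never reaches the kernel.
  // The "< 0" part is an assumption added for this sketch.
  if (last_layer_size < 0 || last_layer_size > kMaxLastLayerSize) {
    return {false, false, "last_layer_size must be <= 1024"};
  }
  fake_kernel(out, last_layer_size);
  return {true, true, ""};
}
```

Calling `guarded_launch(out, 2048)` returns a failed result and leaves `out` untouched, which is the behavior the comment asks for in each of the four GPU branches.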
---
Outside diff comments:
In `@source/op/tf/prod_force_se_a_mask.cc`:
- Around line 39-68: The shape validation in this op is incomplete, so add the
missing guards in the same setup block before any flattening or indexing in the
main compute path. Use the existing OP_REQUIRES checks around net_deriv_tensor,
in_deriv_tensor, mask_tensor, and nlist_tensor to verify the expected nframes
and per-frame widths for the fixed [nframes, total_atom_num, ...] layout,
including mask_tensor rows/columns, net_deriv_tensor column count, and
nlist_tensor width, so the later indexing logic cannot read past allocated
tensors.
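
As a rough, TensorFlow-free illustration of the guards described above, the sketch below checks the shapes as plain vectors before any flat indexing would happen. The `Shapes` struct, `validate_shapes`, and the rank-check message are hypothetical. The expected widths (`nloc * ndescrpt`, `total_atom_num`, `total_atom_num * total_atom_num`) and the other messages follow the suggested fix in the review comment, not the actual op source.

```cpp
#include <cstdint>
#include <initializer_list>
#include <string>
#include <vector>

// Hypothetical plain view of the tensor shapes involved.
struct Shapes {
  std::vector<std::int64_t> net_deriv;  // expected [nframes, nloc * ndescrpt]
  std::vector<std::int64_t> mask;       // expected [nframes, total_atom_num]
  std::vector<std::int64_t> nlist;      // expected [nframes, total_atom_num^2]
};

// Returns an empty string when every guard passes, otherwise the message of
// the first failing guard.
inline std::string validate_shapes(const Shapes& s, std::int64_t nframes,
                                   std::int64_t nloc, std::int64_t ndescrpt,
                                   std::int64_t total_atom_num) {
  // Rank check first, so the dimension lookups below are in bounds.
  for (const std::vector<std::int64_t>* shape :
       {&s.net_deriv, &s.mask, &s.nlist}) {
    if (shape->size() != 2) {
      return "input should be rank 2";  // hypothetical message
    }
  }
  for (const std::vector<std::int64_t>* shape :
       {&s.net_deriv, &s.mask, &s.nlist}) {
    if ((*shape)[0] != nframes) {
      return "number of samples should match";
    }
  }
  if (s.net_deriv[1] != nloc * ndescrpt) {
    return "number of descriptors should match";
  }
  if (s.mask[1] != total_atom_num) {
    return "number of atoms should match";
  }
  if (s.nlist[1] != total_atom_num * total_atom_num) {
    return "number of neighbors should match";
  }
  return "";
}
```

In the real op each failing branch would presumably be an `OP_REQUIRES` with `deepmd::tf_compat::InvalidArgument`, as in the suggested fix. Returning a string here just keeps the sketch testable in isolation.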
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 53d5ee3c-0ce1-46b5-806e-2fcbe61182a3
📒 Files selected for processing (34)
- source/op/tf/custom_op.h
- source/op/tf/descrpt.cc
- source/op/tf/descrpt_se_a_ef.cc
- source/op/tf/descrpt_se_a_ef_para.cc
- source/op/tf/descrpt_se_a_ef_vert.cc
- source/op/tf/descrpt_se_a_mask.cc
- source/op/tf/ewald_recp.cc
- source/op/tf/map_aparam.cc
- source/op/tf/neighbor_stat.cc
- source/op/tf/pair_tab.cc
- source/op/tf/pairwise.cc
- source/op/tf/prod_env_mat_multi_device.cc
- source/op/tf/prod_env_mat_multi_device_nvnmd.cc
- source/op/tf/prod_force.cc
- source/op/tf/prod_force_grad.cc
- source/op/tf/prod_force_grad_multi_device.cc
- source/op/tf/prod_force_multi_device.cc
- source/op/tf/prod_force_se_a_grad.cc
- source/op/tf/prod_force_se_a_mask.cc
- source/op/tf/prod_force_se_a_mask_grad.cc
- source/op/tf/prod_force_se_r_grad.cc
- source/op/tf/prod_virial.cc
- source/op/tf/prod_virial_grad.cc
- source/op/tf/prod_virial_grad_multi_device.cc
- source/op/tf/prod_virial_multi_device.cc
- source/op/tf/prod_virial_se_a_grad.cc
- source/op/tf/prod_virial_se_r_grad.cc
- source/op/tf/soft_min.cc
- source/op/tf/soft_min_force.cc
- source/op/tf/soft_min_force_grad.cc
- source/op/tf/soft_min_virial.cc
- source/op/tf/soft_min_virial_grad.cc
- source/op/tf/tabulate_multi_device.cc
- source/op/tf/unaggregated_grad.cc
Codecov Report

❌ Patch coverage is

Additional details and impacted files

    @@            Coverage Diff             @@
    ##           master    #5600      +/-   ##
    ==========================================
    - Coverage   82.35%   82.35%   -0.01%
    ==========================================
      Files         896      896
      Lines      100952   100954       +2
      Branches     4059     4057       -2
    ==========================================
    - Hits        83138    83136       -2
    - Misses      16349    16352       +3
    - Partials     1465     1466       +1

☔ View full report in Codecov by Harness.
## Summary

- add a TensorFlow-version-gated `deepmd::tf_compat::InvalidArgument` helper
- use `absl::InvalidArgumentError` for TensorFlow >= 2.20 and keep `tensorflow::errors::InvalidArgument` for older TensorFlow
- route TF custom-op `OP_REQUIRES` InvalidArgument checks through the helper

Fixes deepmodeling#5006

@OutisLi Could you review this PR?

## Tests

- `uvx ruff==0.15.18 check .`
- `uvx ruff==0.15.18 format .`
- `DP_ENABLE_PYTORCH=0 uv pip install -e '.[cpu,test]'` with TensorFlow 2.21.0
- `dp --version`
- `dp -h`, `dp --tf -h`, `dp --pt -h`, `dp --jax -h`, `dp --pd -h`
- `python -c "import deepmd; import deepmd.tf; print('Both interfaces work')"`
- `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v`
- standalone C++ build/install with TensorFlow 2.21.0 and PyTorch enabled, using `DEEPMD_BYPASS_TORCH_CUDA_CHECK=ON` for the local CPU PyTorch wheel

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **Bug Fixes**
  * Improved TensorFlow compatibility across many operators by standardizing invalid-input error handling.
  * Validation failures now use a framework-compatible error path across supported TensorFlow versions.
  * Kept the same input checks and messages while making shape/rank mismatch errors more consistent.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
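
For readers who want a feel for the version gate described in the summary, here is a minimal sketch that compiles without TensorFlow. It is not the helper added by this PR. The `Status` struct and the `stub_absl`, `stub_tf_errors`, and `tf_compat_sketch` namespaces are hypothetical stand-ins, and the `TF_MAJOR_VERSION` / `TF_MINOR_VERSION` macros are defined by hand here on the assumption that a real build would get them from TensorFlow's version header.

```cpp
#include <string>

// Assumption: hand-defined version macros so the sketch builds standalone.
#ifndef TF_MAJOR_VERSION
#define TF_MAJOR_VERSION 2
#endif
#ifndef TF_MINOR_VERSION
#define TF_MINOR_VERSION 21
#endif

// Hypothetical status type that records which factory produced the error.
struct Status {
  std::string origin;
  std::string message;
};

// Stub standing in for absl::InvalidArgumentError.
namespace stub_absl {
inline Status InvalidArgumentError(const std::string& msg) {
  return {"absl", msg};
}
}  // namespace stub_absl

// Stub standing in for tensorflow::errors::InvalidArgument.
namespace stub_tf_errors {
inline Status InvalidArgument(const std::string& msg) {
  return {"tensorflow::errors", msg};
}
}  // namespace stub_tf_errors

namespace tf_compat_sketch {
// Single call site for ops: the preprocessor picks the factory at build time.
inline Status InvalidArgument(const std::string& msg) {
#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 20)
  return stub_absl::InvalidArgumentError(msg);
#else
  return stub_tf_errors::InvalidArgument(msg);
#endif
}
}  // namespace tf_compat_sketch
```

With the macros set to 2.21 as above, `tf_compat_sketch::InvalidArgument("number of samples should match")` goes through the `stub_absl` path. Building with `-DTF_MINOR_VERSION=19` would select the older path instead. Keeping the gate in one wrapper means each `OP_REQUIRES` call site stays unchanged across TensorFlow versions.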