fix(tf): add InvalidArgument compatibility wrapper #5600

Merged
OutisLi merged 1 commit into master from fix/tf-invalid-argument-compat on Jun 29, 2026

@njzjz njzjz (Member) commented Jun 28, 2026

Summary

  • add a TensorFlow-version-gated deepmd::tf_compat::InvalidArgument helper
  • use absl::InvalidArgumentError for TensorFlow >= 2.20 and keep tensorflow::errors::InvalidArgument for older TensorFlow
  • route TF custom-op OP_REQUIRES InvalidArgument checks through the helper

Fixes #5006

@OutisLi Could you review this PR?

Tests

  • uvx ruff==0.15.18 check .
  • uvx ruff==0.15.18 format .
  • DP_ENABLE_PYTORCH=0 uv pip install -e '.[cpu,test]' with TensorFlow 2.21.0
  • dp --version
  • dp -h, dp --tf -h, dp --pt -h, dp --jax -h, dp --pd -h
  • python -c "import deepmd; import deepmd.tf; print('Both interfaces work')"
  • pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v
  • standalone C++ build/install with TensorFlow 2.21.0 and PyTorch enabled, using DEEPMD_BYPASS_TORCH_CUDA_CHECK=ON for the local CPU PyTorch wheel

Summary by CodeRabbit

  • Bug Fixes
    • Improved TensorFlow compatibility across many operators by standardizing invalid-input error handling.
    • Validation failures now use a framework-compatible error path across supported TensorFlow versions.
    • Kept the same input checks and messages while making shape/rank mismatch errors more consistent.

@dosubot dosubot Bot added the bug label Jun 28, 2026
@njzjz njzjz requested a review from OutisLi June 28, 2026 15:11
@coderabbitai coderabbitai Bot (Contributor) commented Jun 28, 2026

📝 Walkthrough

Introduces a deepmd::tf_compat namespace in custom_op.h with a Status alias and an InvalidArgument template that dispatches to absl::InvalidArgumentError (TF ≥ 2.20) or tensorflow::errors::InvalidArgument (older TF). All OP_REQUIRES validation calls across 30+ TF op kernel source files are then migrated to use deepmd::tf_compat::InvalidArgument.
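
The dispatch described above can be sketched without TensorFlow installed. What follows is a minimal illustration of compile-time version gating, not the code from `custom_op.h`. `DEMO_TF_MAJOR`, `DEMO_TF_MINOR`, `compat_demo::Status`, `NewStyleInvalidArgument`, and `LegacyInvalidArgument` are hypothetical stand-ins for TensorFlow's version macros, its status type, `absl::InvalidArgumentError`, and `tensorflow::errors::InvalidArgument`. This thread does not show how the real helper joins its arguments, so the stream-based concatenation is an assumption.

```cpp
#include <sstream>
#include <string>
#include <utility>

// Hypothetical stand-ins for TensorFlow's version macros so the sketch
// builds without TensorFlow; a real header would use TensorFlow's own.
#ifndef DEMO_TF_MAJOR
#define DEMO_TF_MAJOR 2
#endif
#ifndef DEMO_TF_MINOR
#define DEMO_TF_MINOR 21
#endif

namespace compat_demo {

// Stand-in for the framework status type.
struct Status {
  int code;             // 3 is the canonical INVALID_ARGUMENT code
  std::string message;  // human-readable description
  std::string origin;   // records which branch built the status
};

// Stand-in for absl::InvalidArgumentError (accepts a single message).
inline Status NewStyleInvalidArgument(const std::string& msg) {
  return Status{3, msg, "absl-style"};
}

// Stand-in for tensorflow::errors::InvalidArgument.
inline Status LegacyInvalidArgument(const std::string& msg) {
  return Status{3, msg, "legacy"};
}

namespace tf_compat {

// Joins all arguments into one message, then selects the error factory
// at compile time from the (stand-in) TensorFlow version.
template <typename... Args>
Status InvalidArgument(Args&&... args) {
  std::ostringstream oss;
  // C++11-compatible pack expansion: stream every argument in order.
  int expand[] = {0, ((oss << std::forward<Args>(args)), 0)...};
  (void)expand;
#if (DEMO_TF_MAJOR > 2) || (DEMO_TF_MAJOR == 2 && DEMO_TF_MINOR >= 20)
  return NewStyleInvalidArgument(oss.str());
#else
  return LegacyInvalidArgument(oss.str());
#endif
}

}  // namespace tf_compat
}  // namespace compat_demo
```

With the default stand-in version (2.21) a call takes the first branch; compiling with `-DDEMO_TF_MINOR=19` selects the legacy branch instead. Call sites are identical in both cases, which is the point of routing every check through one helper.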

Changes

TF compatibility error helper and op migration

| Layer / File(s) | Summary |
| --- | --- |
| deepmd::tf_compat helper: source/op/tf/custom_op.h | Adds <utility> and conditional Abseil includes; defines tf_compat::Status alias and InvalidArgument template selecting between absl::InvalidArgumentError and tensorflow::errors::InvalidArgument based on TF version macros. |
| OP_REQUIRES migration across all TF op kernels: source/op/tf/descrpt*.cc, source/op/tf/prod_env_mat*.cc, source/op/tf/prod_force*.cc, source/op/tf/prod_virial*.cc, source/op/tf/soft_min*.cc, source/op/tf/tabulate_multi_device.cc, source/op/tf/unaggregated_grad.cc, source/op/tf/ewald_recp.cc, source/op/tf/map_aparam.cc, source/op/tf/neighbor_stat.cc, source/op/tf/pair_tab.cc, source/op/tf/pairwise.cc | Replaces all errors::InvalidArgument(...) calls inside OP_REQUIRES validation blocks with deepmd::tf_compat::InvalidArgument(...). Validation conditions, error messages, and control flow are unchanged. |
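
For readers unfamiliar with the pattern being migrated: `OP_REQUIRES` evaluates a condition and, when it is false, records a status on the kernel context and returns from `Compute`. The sketch below imitates that control flow in plain C++ to show that only the error factory changes at each call site. `DEMO_OP_REQUIRES`, `DemoContext`, `ComputeLike`, and the local `InvalidArgument` are hypothetical stand-ins; TensorFlow's real macro does more than this, and the error message is only illustrative.

```cpp
#include <string>

namespace requires_demo {

struct Status {
  bool ok;
  std::string message;
};

// Hypothetical stand-in for deepmd::tf_compat::InvalidArgument.
inline Status InvalidArgument(const std::string& msg) {
  return Status{false, msg};
}

// Hypothetical stand-in for the kernel context: keeps the first failure.
struct DemoContext {
  Status status{true, ""};
  void SetFailure(const Status& s) {
    if (status.ok) status = s;
  }
};

}  // namespace requires_demo

// Simplified stand-in for OP_REQUIRES: on a false condition, record the
// status on the context and return from the enclosing void function.
#define DEMO_OP_REQUIRES(ctx, cond, status_expr) \
  do {                                           \
    if (!(cond)) {                               \
      (ctx)->SetFailure(status_expr);            \
      return;                                    \
    }                                            \
  } while (0)

namespace requires_demo {

// A Compute-like function. In the migration only the factory name changes;
// the condition and the message stay the same.
inline void ComputeLike(DemoContext* ctx, int coord_rank, bool* ran_body) {
  // Before the migration this line would have called errors::InvalidArgument.
  DEMO_OP_REQUIRES(ctx, coord_rank == 2,
                   InvalidArgument("Dim of coord should be 2"));
  *ran_body = true;  // reached only when validation passed
}

}  // namespace requires_demo
```

Because the macro returns early, everything after a failed check is skipped, which is why the review comments below focus on placing the checks before any indexing.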

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

C++, Core, enhancement

Suggested reviewers

  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title is concise and accurately summarizes the main change: adding a TensorFlow InvalidArgument compatibility wrapper. |
| Linked Issues check | ✅ Passed | The PR adds a TensorFlow-version-gated InvalidArgument wrapper and updates the affected tf custom ops, addressing #5006's deprecation warnings. |
| Out of Scope Changes check | ✅ Passed | All code changes stay within TensorFlow custom-op validation and the shared compat helper, with no unrelated features or subsystems added. |

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
source/op/tf/prod_force_se_a_mask.cc (1)

39-68: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Complete the shape guards before flattening these tensors.

This block still leaves mask_tensor, net_deriv_tensor, and nlist_tensor only partially validated. Lines 84-145 index them as fixed [nframes, total_atom_num, ...] layouts, so short mask/frame dimensions, too few net_deriv columns, or a narrow nlist will read past the allocation instead of failing fast.

Suggested fix
     OP_REQUIRES(
         context, (nframes == in_deriv_tensor.shape().dim_size(0)),
         deepmd::tf_compat::InvalidArgument("number of samples should match"));
     OP_REQUIRES(
         context, (nframes == nlist_tensor.shape().dim_size(0)),
         deepmd::tf_compat::InvalidArgument("number of samples should match"));
+    OP_REQUIRES(
+        context, (nframes == mask_tensor.shape().dim_size(0)),
+        deepmd::tf_compat::InvalidArgument("number of samples should match"));
+    OP_REQUIRES(
+        context,
+        (static_cast<int64_t>(nloc) * ndescrpt ==
+         net_deriv_tensor.shape().dim_size(1)),
+        deepmd::tf_compat::InvalidArgument(
+            "number of descriptors should match"));
     OP_REQUIRES(context,
                 (static_cast<int64_t>(nloc) * ndescrpt * 3 ==
                  in_deriv_tensor.shape().dim_size(1)),
                 deepmd::tf_compat::InvalidArgument(
                     "number of descriptors should match"));
+    OP_REQUIRES(
+        context, (total_atom_num == mask_tensor.shape().dim_size(1)),
+        deepmd::tf_compat::InvalidArgument("number of atoms should match"));
+    OP_REQUIRES(
+        context,
+        (static_cast<int64_t>(total_atom_num) * total_atom_num ==
+         nlist_tensor.shape().dim_size(1)),
+        deepmd::tf_compat::InvalidArgument("number of neighbors should match"));
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/op/tf/prod_force_se_a_mask.cc` around lines 39 - 68, The shape
validation in this op is incomplete, so add the missing guards in the same setup
block before any flattening or indexing in the main compute path. Use the
existing OP_REQUIRES checks around net_deriv_tensor, in_deriv_tensor,
mask_tensor, and nlist_tensor to verify the expected nframes and per-frame
widths for the fixed [nframes, total_atom_num, ...] layout, including
mask_tensor rows/columns, net_deriv_tensor column count, and nlist_tensor width,
so the later indexing logic cannot read past allocated tensors.
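
The guards proposed above amount to checking every tensor's leading and per-frame dimensions before any flat indexing. Here is a standalone sketch of that idea using plain shape structs instead of TensorFlow tensors. `Shape2D`, `ValidateMaskForceShapes`, and the parameter names are hypothetical. The expected widths are copied from the suggested fix above and have not been verified against the kernel, and taking `nframes` from the first dimension of `net_deriv` is an assumption.

```cpp
#include <cstdint>
#include <string>

namespace shape_demo {

// Hypothetical 2-D shape: [rows, cols].
struct Shape2D {
  std::int64_t rows;
  std::int64_t cols;
};

// Returns an empty string when all shapes are consistent, otherwise the
// message of the first failed check (mirroring fail-fast OP_REQUIRES).
inline std::string ValidateMaskForceShapes(
    std::int64_t nloc, std::int64_t ndescrpt, std::int64_t total_atom_num,
    const Shape2D& net_deriv, const Shape2D& in_deriv, const Shape2D& mask,
    const Shape2D& nlist) {
  // Assumption: the frame count is defined by net_deriv's first dimension.
  const std::int64_t nframes = net_deriv.rows;

  // Every input must describe the same number of frames.
  if (in_deriv.rows != nframes || nlist.rows != nframes ||
      mask.rows != nframes) {
    return "number of samples should match";
  }
  // Per-frame widths expected by the fixed [nframes, ...] layout.
  if (net_deriv.cols != nloc * ndescrpt ||
      in_deriv.cols != nloc * ndescrpt * 3) {
    return "number of descriptors should match";
  }
  if (mask.cols != total_atom_num) {
    return "number of atoms should match";
  }
  if (nlist.cols != total_atom_num * total_atom_num) {
    return "number of neighbors should match";
  }
  return "";
}

}  // namespace shape_demo
```

Returning on the first mismatch keeps the behaviour close to a chain of `OP_REQUIRES` calls: a short `mask` or a narrow `nlist` is rejected with a message rather than being indexed out of bounds later.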
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@source/op/tf/descrpt_se_a_mask.cc`:
- Around line 91-96: The `DescrptSeAMask` input validation is missing a size
check for `type_tensor`’s second dimension, even though `type(kk, ii)` is
accessed for every atom in the loop. Add an `OP_REQUIRES` in the same validation
block that compares `type_tensor.shape().dim_size(1)` against `total_atom_num`,
matching the existing checks for `coord_tensor` and `mask_matrix_tensor`, so
malformed inputs are rejected before `compute_input_stats`/the atom loop reads
past the buffer.

In `@source/op/tf/pair_tab.cc`:
- Around line 55-74: `PairTabOp::Compute` only checks the rank of
`table_info_tensor` and `table_data_tensor`, but later indexing in the
`table_info(...)` reads and the `table_data_tensor` access can still go out of
bounds for short 1-D inputs. Add explicit `dim_size(0)` validation for both
tensors before any indexing, using `OP_REQUIRES` with
`deepmd::tf_compat::InvalidArgument`, so invalid lengths fail gracefully instead
of relying on the debug-only assert.

In `@source/op/tf/pairwise.cc`:
- Around line 47-59: The validation in the pairwise op is incomplete:
`idxs_tensor` and `natoms_tensor` only have rank checks before `natoms(0)`,
`natoms(1)`, the `nall`-based loop, and the natoms output writes are used. Add
explicit size/shape checks in the `Pairwise` op path before any indexing or
iteration so malformed inputs fail through `OP_REQUIRES` with `InvalidArgument`;
use the existing `idxs_tensor`, `natoms_tensor`, `nframes`, and `natoms`
handling in this block and the later natoms output section as the locations to
tighten validation.
- Around line 240-244: The shape checks in the pairwise op are incomplete:
`sub_natoms_tensor` and `sub_forward_map_tensor` are used later without
validating their ranks and minimum sizes. Add explicit `OP_REQUIRES` checks in
the same validation block before `sub_natoms_tensor.vec<int>()` and before any
`sub_forward_map(ii, jj)` access, ensuring `sub_natoms_tensor` is 1-D with at
least 2 elements and `sub_forward_map_tensor` is 2-D with dimensions large
enough for the later indexing. Keep the existing `natoms_tensor` validation, and
make the new checks use `InvalidArgument` so malformed inputs fail cleanly
instead of hitting unchecked indexing in the pairwise op.

In `@source/op/tf/prod_virial_grad.cc`:
- Around line 96-98: The virial-grad shape check in OP_REQUIRES enforces a width
of 9, but the InvalidArgument text is still describing the wrong shape. Update
the error message in the prod_virial_grad.cc validation path to match the actual
expected tensor shape checked in the virial-grad logic, using the existing
OP_REQUIRES and deepmd::tf_compat::InvalidArgument call so the failure message
is accurate and consistent.

In `@source/op/tf/tabulate_multi_device.cc`:
- Around line 344-347: Validate last_layer_size before any GPU launch in
tabulate_multi_device.cc: the four GPU branches around
deepmd::tabulate_fusion_*_grad_grad_gpu currently call the kernel and only
afterward run OP_REQUIRES on last_layer_size <= 1024. Move the guard ahead of
each kernel invocation so invalid widths are rejected before launch, using the
existing OP_REQUIRES/deepmd::tf_compat::InvalidArgument check in each branch.
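
The `tabulate_multi_device.cc` item is purely about ordering: the size limit should be checked before the kernel runs, not after. The sketch below contrasts the two orderings with a counter standing in for the GPU launch. `FakeLaunch`, `CheckAfterLaunch`, `CheckBeforeLaunch`, `Result`, and the error text are hypothetical; only the limit of 1024 comes from the comment above.

```cpp
#include <string>

namespace launch_demo {

const int kMaxLastLayerSize = 1024;  // limit quoted in the review comment

struct Result {
  bool ok;
  std::string message;
  int launches;  // how many times the stand-in kernel ran
};

// Stand-in for a GPU kernel launch; it only counts invocations.
inline void FakeLaunch(int* launches) { ++*launches; }

// Ordering flagged by the review: launch first, validate afterwards.
inline Result CheckAfterLaunch(int last_layer_size) {
  int launches = 0;
  FakeLaunch(&launches);
  if (last_layer_size > kMaxLastLayerSize) {
    return {false, "last layer size must not exceed 1024", launches};
  }
  return {true, "", launches};
}

// Suggested ordering: reject invalid widths before any launch happens.
inline Result CheckBeforeLaunch(int last_layer_size) {
  int launches = 0;
  if (last_layer_size > kMaxLastLayerSize) {
    return {false, "last layer size must not exceed 1024", launches};
  }
  FakeLaunch(&launches);
  return {true, "", launches};
}

}  // namespace launch_demo
```

Both functions report the same error for an oversized input, but only the second one guarantees that the kernel never ran with it.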

ℹ️ Review info
📥 Commits

Reviewing files that changed from the base of the PR and between a9bcbc5 and 330696e.

📒 Files selected for processing (34)
  • source/op/tf/custom_op.h
  • source/op/tf/descrpt.cc
  • source/op/tf/descrpt_se_a_ef.cc
  • source/op/tf/descrpt_se_a_ef_para.cc
  • source/op/tf/descrpt_se_a_ef_vert.cc
  • source/op/tf/descrpt_se_a_mask.cc
  • source/op/tf/ewald_recp.cc
  • source/op/tf/map_aparam.cc
  • source/op/tf/neighbor_stat.cc
  • source/op/tf/pair_tab.cc
  • source/op/tf/pairwise.cc
  • source/op/tf/prod_env_mat_multi_device.cc
  • source/op/tf/prod_env_mat_multi_device_nvnmd.cc
  • source/op/tf/prod_force.cc
  • source/op/tf/prod_force_grad.cc
  • source/op/tf/prod_force_grad_multi_device.cc
  • source/op/tf/prod_force_multi_device.cc
  • source/op/tf/prod_force_se_a_grad.cc
  • source/op/tf/prod_force_se_a_mask.cc
  • source/op/tf/prod_force_se_a_mask_grad.cc
  • source/op/tf/prod_force_se_r_grad.cc
  • source/op/tf/prod_virial.cc
  • source/op/tf/prod_virial_grad.cc
  • source/op/tf/prod_virial_grad_multi_device.cc
  • source/op/tf/prod_virial_multi_device.cc
  • source/op/tf/prod_virial_se_a_grad.cc
  • source/op/tf/prod_virial_se_r_grad.cc
  • source/op/tf/soft_min.cc
  • source/op/tf/soft_min_force.cc
  • source/op/tf/soft_min_force_grad.cc
  • source/op/tf/soft_min_virial.cc
  • source/op/tf/soft_min_virial_grad.cc
  • source/op/tf/tabulate_multi_device.cc
  • source/op/tf/unaggregated_grad.cc

@github-actions github-actions Bot added the OP label Jun 28, 2026
@codecov codecov Bot commented Jun 28, 2026

Codecov Report

❌ Patch coverage is 0% with 262 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.35%. Comparing base (a9bcbc5) to head (330696e).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| source/op/tf/prod_env_mat_multi_device.cc | 0.00% | 30 Missing ⚠️ |
| source/op/tf/prod_env_mat_multi_device_nvnmd.cc | 0.00% | 20 Missing ⚠️ |
| source/op/tf/prod_virial_grad_multi_device.cc | 0.00% | 17 Missing ⚠️ |
| source/op/tf/prod_virial_multi_device.cc | 0.00% | 14 Missing ⚠️ |
| source/op/tf/prod_force_grad_multi_device.cc | 0.00% | 13 Missing ⚠️ |
| source/op/tf/descrpt_se_a_ef.cc | 0.00% | 11 Missing ⚠️ |
| source/op/tf/descrpt_se_a_ef_para.cc | 0.00% | 11 Missing ⚠️ |
| source/op/tf/descrpt_se_a_ef_vert.cc | 0.00% | 11 Missing ⚠️ |
| source/op/tf/descrpt.cc | 0.00% | 10 Missing ⚠️ |
| source/op/tf/prod_force_multi_device.cc | 0.00% | 10 Missing ⚠️ |

... and 19 more

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5600      +/-   ##
==========================================
- Coverage   82.35%   82.35%   -0.01%     
==========================================
  Files         896      896              
  Lines      100952   100954       +2     
  Branches     4059     4057       -2     
==========================================
- Hits        83138    83136       -2     
- Misses      16349    16352       +3     
- Partials     1465     1466       +1     

@OutisLi OutisLi added this pull request to the merge queue Jun 29, 2026
Merged via the queue into master with commit 58bef11 Jun 29, 2026
131 checks passed
@OutisLi OutisLi deleted the fix/tf-invalid-argument-compat branch June 29, 2026 12:39
SchrodingersCattt pushed a commit to SchrodingersCattt/deepmd-kit that referenced this pull request Jun 30, 2026
Development

Successfully merging this pull request may close these issues.

[Bug] Compilation warnings with recent TensorFlow due to deprecated API 'tsl::errors::InvalidArgument'

2 participants