fix(pd): preserve dtype in Paddle RepFlow dynamic aggregation #5712
wanghan-iapcm wants to merge 1 commit into
aggregate() allocated its output with paddle.zeros(...) without a dtype, so it fell back to paddle's default float (float32) and then cast the input data to that dtype before index_add_. For float64 RepFlow/DPA3 models this silently downcast descriptor updates to float32. Allocate the output with dtype=data.dtype so the input precision is preserved. Add a test that the summation path keeps float64. Fix deepmodeling#5688
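To give a feel for what the silent downcast costs, the sketch below rounds a float64 value to float32 and back using only Python's standard `struct` module. It is not Paddle code: `as_float32` is a hypothetical helper introduced for illustration, and the value 0.1 is arbitrary.

```python
import struct


def as_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32, then widen it back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


value = 0.1                    # held as float64 by Python
rounded = as_float32(value)    # what survives a float32 round trip
error = abs(rounded - value)   # precision lost by the downcast

print(f"float64 value : {value!r}")
print(f"after float32 : {rounded!r}")
print(f"absolute error: {error:.3e}")
```

Values that float32 can represent exactly (such as 0.5) survive unchanged, which is one reason a downcast like this can go unnoticed in small tests.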
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pd/model/network/utils.py`:
- Around line 42-44: The `output = paddle.zeros(...)` allocation in `utils.py`
does not explicitly guarantee it is created on the same device as `data` and
`owners`, so update this helper to create `output` inside the matching device
context used by `index_add_` (or otherwise bind it to the input tensor’s device)
while preserving `data.dtype`; if you do not enforce that, revise the
surrounding comment to remove the “same device” claim.
📒 Files selected for processing (2)
deepmd/pd/model/network/utils.py
source/tests/pd/model/test_aggregate.py
# make sure this operation is done on the same device of data and owners,
# and preserve the input dtype instead of falling back to the default float
output = paddle.zeros([num_owner, data.shape[1]], dtype=data.dtype)
🗄️ Data Integrity & Integration | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🌐 Web query:
Does paddle.zeros accept a "place" keyword argument to control output tensor device in current stable Paddle releases?
💡 Result:
No, the standard stable releases of PaddlePaddle do not provide a "place" keyword argument for the paddle.zeros function to control the output tensor device [1][2][3][4]. In current stable versions, paddle.zeros accepts the arguments: shape, dtype, and name [2][3][4]. The function creates the tensor on the currently set global device, which can be managed using paddle.device.set_device [5][6]. While some internal development branches (represented in recent development documentation) may show an expanded signature including a "device" parameter [5], this is not part of the standard stable API. To create a tensor on a specific device in stable Paddle, it is common practice to use context management (e.g., paddle.device.device_guard) [7] or to create the tensor on the default device and then use the Tensor.to method to move it to the desired device [8].
Citations:
- 1: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/zeros_en.html
- 2: https://www.paddlepaddle.org.cn/documentation/docs/en/3.0/api/paddle/zeros_en.html
- 3: https://paddlepaddle-static.cdn.bcebos.com/documentation/docs/en/api/paddle/zeros_en.html
- 4: https://www.paddlepaddle.org.cn/documentation/docs/en/2.6/api/paddle/zeros_en.html
- 5: https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/zeros_en.html
- 6: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/device/__init__.py
- 7: https://github.com/PaddlePaddle/Paddle/wiki/PaddlePaddle-3.1.1-Release-Note-EN
- 8: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/Tensor/Overview_en.html
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate the target file and show the relevant section with line numbers.
git ls-files deepmd/pd/model/network/utils.py
echo '---'
sed -n '1,140p' deepmd/pd/model/network/utils.py
echo '--- SEARCH: device guards / set_device usage near this utility'
rg -n "set_device|device_guard|paddle\.zeros|index_add_" deepmd/pd/model/network -S
Repository: deepmodeling/deepmd-kit
Length of output: 5831
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Show the file with line numbers if present.
if [ -f deepmd/pd/model/network/utils.py ]; then
  nl -ba deepmd/pd/model/network/utils.py | sed -n '1,140p'
fi
Repository: deepmodeling/deepmd-kit
Length of output: 201
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the target function and nearby device-context handling.
nl -ba deepmd/pd/model/network/utils.py | sed -n '1,180p'
echo '---'
rg -n "set_device|device_guard|place=|paddle\.zeros|index_add_" deepmd/pd/model/network/utils.py deepmd/pd/model/network -S
Repository: deepmodeling/deepmd-kit
Length of output: 201
Make the device guarantee explicit
`paddle.zeros` here only follows the current device; it does not enforce a match with `data`/`owners`. If this helper can be called with tensors on a non-default device, create `output` under the matching device context before `index_add_`, or drop the “same device” wording from the comment.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5712      +/-   ##
==========================================
- Coverage   81.97%   81.11%   -0.86%
==========================================
  Files         959      981      +22
  Lines      105748   109859    +4111
  Branches     4102     4235     +133
==========================================
+ Hits        86684    89113    +2429
- Misses      17573    19220    +1647
- Partials     1491     1526      +35
Problem
The Paddle `aggregate()` helper in `deepmd/pd/model/network/utils.py` allocated its output with `paddle.zeros([num_owner, data.shape[1]])` without specifying a dtype, so it fell back to Paddle's default floating dtype (float32). It then cast the input `data` to `output.dtype` before `index_add_`. For float64 Paddle RepFlow/DPA3 models, the dynamic-selection aggregation therefore accumulated descriptor updates in float32, silently downcasting intermediate values before returning them to the descriptor path.
Fix
Allocate the output with `dtype=data.dtype` so the input precision is preserved (and the subsequent `data.astype(output.dtype)` becomes a no-op for matching dtypes).
Test
A new test aggregates float64 input and asserts the result stays float64 (and has the expected values). On the current code the output is float32; after the fix it is float64. This exercises the shared `output = paddle.zeros(..., dtype=data.dtype)` allocation via the summation path.
Note on verification
Verified locally with `paddlepaddle==3.3.1`. The CI target is a newer nightly (`paddlepaddle==3.4.0.dev20260310`), but the behavior fixed here is version-agnostic: `paddle.zeros` without `dtype` defaults to float32 in all versions, and `dtype=data.dtype` corrects it regardless. The test is scoped to the summation path to keep it independent of an unrelated `Tensor.where` API difference in the older local Paddle.
Fix #5688
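The dtype fallback described under Problem, and the kind of check described under Test, can be imitated without Paddle. The sketch below is a standard-library analogue, not the code in `deepmd/pd/model/network/utils.py`: `aggregate_sum` and its `preserve_dtype` flag are hypothetical names, and `array.array` typecodes (`"d"` for float64, `"f"` for float32) stand in for tensor dtypes.

```python
from array import array


def aggregate_sum(data, owners, num_owner, preserve_dtype=True):
    """Sum the rows of `data` into `num_owner` buckets chosen by `owners`.

    `data` is a list of equal-length array.array rows. Hypothetical analogue
    of the summation path; the typecode plays the role of the tensor dtype.
    """
    width = len(data[0])
    # Fixed behaviour: allocate the output with the input's typecode.
    # Old behaviour: fall back to a float32 default regardless of the input.
    typecode = data[0].typecode if preserve_dtype else "f"
    output = [array(typecode, [0.0] * width) for _ in range(num_owner)]
    for row, owner in zip(data, owners):
        for j, value in enumerate(row):
            # The running sum is stored at the output's precision.
            output[owner][j] += value
    return output


data = [array("d", [0.1, 1.0]), array("d", [0.2, 2.0]), array("d", [0.3, 3.0])]
owners = [0, 0, 1]

fixed = aggregate_sum(data, owners, num_owner=2)
legacy = aggregate_sum(data, owners, num_owner=2, preserve_dtype=False)

print(fixed[0].typecode, legacy[0].typecode)
```

As in the PR's test, the regression check has two parts: the output keeps the input's precision, and the summed values match what float64 arithmetic gives. The legacy path still produces plausible numbers, which is why the downcast was easy to miss.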
Summary by CodeRabbit
Bug Fixes
Tests
`float64` precision and produces the expected summed results.