
fix: dpa4 compile aliasing collision #5483

Merged
anyangml merged 9 commits into deepmodeling:master from anyangml:fix/dpa4-multitask-compile
Jun 6, 2026

Conversation

@anyangml anyangml commented Jun 1, 2026

Collaborator

Summary by CodeRabbit

  • Bug Fixes

    • Improved edge-list construction and tracing stability to avoid out-of-range/index issues during model tracing.
  • Performance

    • Reused compiled callables across compatible multitask branches via shared caching for faster multi-task compilation and reduced duplication.
  • Changes

    • Model export/tracing now consistently includes the charge_spin input slot to match runtime expectations.
  • Tests

    • Updated multitask compile-cache test to expect callable reuse across compatible branches.

Copilot AI review requested due to automatic review settings June 1, 2026 08:43
@anyangml anyangml marked this pull request as draft June 1, 2026 08:43
@dosubot dosubot Bot added the bug label Jun 1, 2026
@github-actions github-actions Bot added the Python label Jun 1, 2026
@coderabbitai

coderabbitai Bot commented Jun 1, 2026

Contributor

Review Change Stack

📝 Walkthrough


SeZMModel now shares compiled ener callables across tasks with compatible module structure by promoting per-task buffers to FX placeholders and passing them as runtime varargs. Tracing uses prime-based safe trace sizes and pads/clamps inputs. Edge-list indexing/docs updated for symbolic tracing stability. Tests adjusted to expect callable reuse.

Changes

Multi-task compile cache with shared graphs and export adjustments

| Layer / File(s) | Summary |
|---|---|
| Top-level and dummy-edge docs / `deepmd/pt/model/model/sezm_model.py` | Module and NOTE docstrings updated to describe the appended masked dummy edges and index_select‑based edge-vector construction under symbolic shapes. |
| Cache infrastructure and helper functions / `deepmd/pt/model/model/sezm_model.py` | Add module-level compile cache, _sezm_structure_key, and helpers to detect/promote task-specific buffer names and fetch current per-instance buffer tensor values. |
| Trace shape hardening with prime dimensions / `deepmd/pt/model/model/sezm_model.py` | Add _is_prime, _next_safe_prime, and _trace_pad_dim to choose collision-free prime trace dimensions and coerce/pad real trace tensors to those sizes. |
| trace_and_compile refactor with buffer detection and patching / `deepmd/pt/model/model/sezm_model.py` | Check module-level shared cache early, compute structure/compile keys, detect promoted buffers, and define patch/restore closures that swap buffers into the model during tracing. |
| Compute function closures with patching and prime-based shapes / `deepmd/pt/model/model/sezm_model.py` | Both coord-correction compute_fn closures accept *task_buf_vals, patch/restore promoted buffers during tracing, and the extended_coord_corr path applies prime-based selection and pads/trims/clamps trace inputs. |
| Trace input construction with promoted buffer tensors / `deepmd/pt/model/model/sezm_model.py` | Append per-task buffer tensors after fixed trace args so make_fx creates separate placeholders for each promoted buffer; record per-instance buffer ordering for reuse. |
| Cache population and compiler tuning / `deepmd/pt/model/model/sezm_model.py` | Store compiled callable into per-instance and module-level caches and record task-buffer ordering metadata; document avoiding dist.barrier() to prevent deadlock. |
| Runtime buffer values in compiled forward path / `deepmd/pt/model/model/sezm_model.py` | At compiled ener forward, read current per-task buffer tensors and pass them as extra positional varargs to the compiled callable for both branches so the compiled graph is task-reusable. |
| Export path charge_spin threading / `deepmd/pt/model/model/sezm_model.py` | Ensure forward_common_lower_exportable trace includes a charge_spin slot so the traced closure matches the runtime 7-tuple signature expected by freezing. |
| Multitask compile cache test coverage / `source/tests/pt/model/test_sezm_model.py` | Update test commentary and change assertion to require that branches sharing descriptor/fitting reuse the same compiled callable object (assertIs), while branch cache dicts remain separate. |
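
The "trace shape hardening" entry names `_is_prime`, `_next_safe_prime`, and `_trace_pad_dim` without showing them. A minimal sketch of what such helpers might look like follows; the names without the leading underscore, their signatures, the "skip already-used primes" rule, and the list-based padding are all assumptions for illustration, not the PR's code.

```python
def is_prime(n: int) -> bool:
    """Trial-division primality test, fine for small trace sizes."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True


def next_safe_prime(start: int, used: set[int]) -> int:
    """Smallest prime >= max(start, 2) that no other dimension uses yet."""
    n = max(start, 2)
    while not is_prime(n) or n in used:
        n += 1
    return n


def pick_trace_dims(hints: dict[str, int]) -> dict[str, int]:
    """Give every named dimension its own distinct prime trace size."""
    used: set[int] = set()
    dims: dict[str, int] = {}
    for name, hint in hints.items():
        p = next_safe_prime(hint, used)
        used.add(p)
        dims[name] = p
    return dims


def pad_or_trim(values: list[int], size: int, fill: int = 0) -> list[int]:
    """Coerce a 1-D example input to exactly `size` entries."""
    return values[:size] + [fill] * max(0, size - len(values))
```

With hints `{"nloc": 4, "nall": 4, "nedge": 12}` this picks 5, 7, and 13: no two dimensions share a size, and since the sizes are distinct primes, none equals the product of two others.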

Edge-list index and documentation updates

| Layer / File(s) | Summary |
|---|---|
| Edge list index arithmetic and docs / `deepmd/pt/model/model/sezm_model.py` | Derive dst_actual from neighbor_flat.shape[0], update f_idx/dst_local arithmetic, and tighten docstrings/return-shape text describing the appended masked dummy edges and their False masks. |
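
As a rough illustration of the "appended masked dummy edges", here is a plain-Python sketch. The function name, the `-1` padding convention for missing neighbours, the src/dst orientation, and pointing dummy edges at atom 0 are assumptions; the real code works on tensors in `sezm_model.py`. The default of two dummy edges follows the `dummy_count = 2` mentioned later in this review thread.

```python
def build_edge_list(
    nlist: list[list[int]], dummy_count: int = 2
) -> tuple[list[int], list[int], list[bool]]:
    """Flatten a padded neighbour list into (src, dst, mask) with a dummy tail."""
    src: list[int] = []
    dst: list[int] = []
    for center, neighbors in enumerate(nlist):
        for nbr in neighbors:
            if nbr >= 0:  # -1 marks an empty neighbour slot
                src.append(nbr)
                dst.append(center)
    n_real = len(src)
    # Dummy edges carry valid indices but a False mask, so downstream
    # reductions can ignore them while the edge count never drops to 0 or 1.
    src += [0] * dummy_count
    dst += [0] * dummy_count
    mask = [True] * n_real + [False] * dummy_count
    return src, dst, mask
```

Even for a frame with no real edges the result has length `dummy_count`, which is the property the review comments tie to the E == 0/1 symbolic-trace failures.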

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5457: Related FX/torch.compile changes that detect task buffers, build a structure key, and pass promoted buffers as varargs to enable multi-task compiled-graph reuse.

Suggested reviewers

  • wanghan-iapcm
  • njzjz-bot
  • njzjz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Title check | ❓ Inconclusive | The title 'fix: dpa4 compile aliasing collision' is vague and does not clearly convey the main technical change, which involves SeZM's compiled energy path, multi-task compile sharing, module-level caching, and trace-shape selection improvements. | Consider revising the title to be more specific about the primary change, such as 'fix: enable SeZM multi-task compile sharing with structure-based caching' or similar, to better reflect the technical scope of the modifications. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.95% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Comment thread deepmd/pt/model/model/sezm_model.py Dismissed
Comment thread deepmd/pt/model/model/sezm_model.py Dismissed

Copilot AI left a comment

Contributor

Pull request overview

This PR attempts to improve/repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompiles/OOM in multi-task setups and addressing symbolic-shape tracing issues in make_fx.

Changes:

  • Add module-level compile sharing and promote selected per-task buffers (e.g., out_bias, bias_atom_e, case_embd) as FX inputs to enable compiled-graph reuse across shared-parameter tasks.
  • Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable ShapeEnv duck sizing during tracing.
  • Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.
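
The aliasing that the second bullet refers to can be pictured with a toy model of duck sizing. This is a simplified sketch, not PyTorch's ShapeEnv: the function name and the symbol naming are made up, and the only behaviour modelled is that, with duck sizing on, input dimensions that happen to have equal sizes at trace time are given one shared symbol.

```python
def assign_symbols(dims: dict[str, int], duck_shape: bool) -> dict[str, str]:
    """Map each named dimension to a symbol name such as 's0'."""
    first_symbol_for_size: dict[int, str] = {}
    symbols: dict[str, str] = {}
    for i, (name, size) in enumerate(dims.items()):
        if duck_shape and size in first_symbol_for_size:
            # Equal example sizes collapse onto one symbol: the graph then
            # assumes these dimensions are always equal.
            symbols[name] = first_symbol_for_size[size]
        else:
            sym = f"s{i}"
            first_symbol_for_size.setdefault(size, sym)
            symbols[name] = sym
    return symbols
```

For `{"nloc": 8, "nnei": 8, "nall": 20}` the duck-sized result ties `nloc` and `nnei` to the same symbol, which would be wrong for any later input where they differ. Either disabling duck sizing during tracing or choosing distinct example sizes avoids the collision in this toy model.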

Comment thread deepmd/pt/model/model/sezm_model.py
Comment thread deepmd/pt/model/model/sezm_model.py
Comment thread deepmd/pt/model/model/sezm_model.py Outdated
Comment thread deepmd/pt/model/model/sezm_model.py Outdated
@codecov

codecov Bot commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 91.19171% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.39%. Comparing base (967e525) to head (b23ce5e).
⚠️ Report is 7 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/pt/model/model/sezm_model.py | 91.19% | 17 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5483      +/-   ##
==========================================
+ Coverage   81.34%   81.39%   +0.04%     
==========================================
  Files         868      868              
  Lines       96373    96772     +399     
  Branches     4233     4240       +7     
==========================================
+ Hits        78399    78771     +372     
- Misses      16675    16698      +23     
- Partials     1299     1303       +4     

@anyangml anyangml marked this pull request as ready for review June 2, 2026 10:29

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)

2735-2744: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clear the shared compile cache when the ener head is reset.

This branch clears only self.compiled_core_compute_cache, but Line 1721 can immediately repopulate it from _SEZM_COMPILE_CACHE when the structure key is unchanged. That bypasses the retrace promised by this method and can resurrect a callable traced against the pre-reset head.

Suggested change
         else:
+            structure_key = _sezm_structure_key(self)
+            stale_keys = [
+                key
+                for key in _SEZM_COMPILE_CACHE
+                if key[: len(structure_key)] == structure_key
+            ]
+            for key in stale_keys:
+                _SEZM_COMPILE_CACHE.pop(key, None)
+            _SEZM_TASK_BUF_ORDER.pop(structure_key, None)
             self._core_compute_pending_compile_t0 = None
             self._core_compute_pending_compile_key = None
             # Drop every compile slot so the next forward retraces against the
             # reinitialised fitting head.
             self.compiled_core_compute_cache.clear()
+            self._task_buf_order_cache.clear()
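
The prefix-based invalidation in the suggested change can be tried in isolation. In this sketch the cache is a plain dict and the keys are small tuples; it assumes, as the suggestion does, that every compile key begins with the structure key. The function name is hypothetical.

```python
def invalidate_structure(cache: dict[tuple, object], structure_key: tuple) -> int:
    """Drop every entry whose key starts with `structure_key`; return the count."""
    prefix_len = len(structure_key)
    # Collect first, then pop: a dict must not change size while iterating.
    stale_keys = [key for key in cache if key[:prefix_len] == structure_key]
    for key in stale_keys:
        cache.pop(key, None)
    return len(stale_keys)
```

After a head reset, calling this with the model's structure key leaves entries for other structures untouched while forcing a retrace for the reset one.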
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` around lines 2735 - 2744, The ener-head
reset branch currently clears only instance cache (compiled_core_compute_cache)
but leaves the shared cache (_SEZM_COMPILE_CACHE) intact so a subsequent lookup
(see code that repopulates from _SEZM_COMPILE_CACHE using the structure key) can
resurrect callables traced against the old head; update the else branch that
resets the ener head to also invalidate the shared compile cache for the same
structure key: when you set _core_compute_pending_compile_key to None and call
compiled_core_compute_cache.clear(), also remove any entries in
_SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the
shared cache entirely) so retracing is forced (operate on the same key variable
used to populate _SEZM_COMPILE_CACHE in the forward path).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 465-483: The current _sezm_structure_key function risks returning
identical keys for different tasks because it only samples the first child of
descriptor and fitting_net; update it to either (preferred) derive the key from
all shared parameter objects (e.g., use a frozenset of ids for all parameters
that are shared between tasks) or (safer/alternative) detect whether the entire
descriptor and fitting_net stacks are fully shared and raise an error if not, so
partial sharing cannot collapse different tasks into the same
_SEZM_COMPILE_CACHE entry; locate the logic in _sezm_structure_key and use
SeZMModel.descriptor / SeZMModel.fitting_net plus the model.share_params
semantics to compute the full parameter-based key or to assert full-stack
sharing before returning a key for cache reuse.
- Line 1772: The loop using zip(task_buf_names, vals) can silently truncate if
the iterables differ; change it to zip(task_buf_names, vals, strict=True) to
make mismatched lengths raise an error. Locate the loop that iterates over
task_buf_names and vals (the line currently written as "for name, val in
zip(task_buf_names, vals):") and update it to include strict=True so any length
mismatch is detected immediately.

---

Outside diff comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 2735-2744: The ener-head reset branch currently clears only
instance cache (compiled_core_compute_cache) but leaves the shared cache
(_SEZM_COMPILE_CACHE) intact so a subsequent lookup (see code that repopulates
from _SEZM_COMPILE_CACHE using the structure key) can resurrect callables traced
against the old head; update the else branch that resets the ener head to also
invalidate the shared compile cache for the same structure key: when you set
_core_compute_pending_compile_key to None and call
compiled_core_compute_cache.clear(), also remove any entries in
_SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the
shared cache entirely) so retracing is forced (operate on the same key variable
used to populate _SEZM_COMPILE_CACHE in the forward path).
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b86f7a23-46a3-4ec5-b8a9-d4a97ba0f001

📥 Commits

Reviewing files that changed from the base of the PR and between 967e525 and 52bcebc.

📒 Files selected for processing (2)
  • deepmd/pt/model/model/sezm_model.py
  • source/tests/pt/model/test_sezm_model.py

Comment thread deepmd/pt/model/model/sezm_model.py
Comment thread deepmd/pt/model/model/sezm_model.py
@anyangml anyangml changed the title from "fix: try fix dpa4 compile" to "fix: dpa4 compile aliasing collision" Jun 3, 2026
@anyangml anyangml added the Test CUDA (Trigger test CUDA workflow) label Jun 3, 2026
@github-actions github-actions Bot removed the Test CUDA (Trigger test CUDA workflow) label Jun 3, 2026
Comment thread deepmd/pt/model/model/sezm_model.py Outdated
Comment thread deepmd/pt/model/model/sezm_model.py Outdated
@OutisLi OutisLi added the Test CUDA (Trigger test CUDA workflow) label Jun 4, 2026
@github-actions github-actions Bot removed the Test CUDA (Trigger test CUDA workflow) label Jun 4, 2026
@anyangml anyangml requested a review from OutisLi June 5, 2026 01:53

@coderabbitai coderabbitai Bot left a comment

Contributor

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)

455-463: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Bound the module-level compile cache lifetime.

_SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER only grow, and nothing removes entries when models/wrappers are discarded. That turns compile sharing into a process-wide retention point for compiled artifacts, so repeated get_sezm_model() / multitask wrapper construction in one worker can accumulate old compiled graphs instead of releasing them.

Also applies to: 2115-2121

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` around lines 455 - 463,
_SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER currently grow forever; change
their implementation to bound lifetime by either (a) using weak references so
entries are removed when the compiled artifacts are garbage-collected (e.g.,
store values in weakref.WeakValueDictionary or store weakrefs to compiled
callables) and/or (b) implementing an explicit LRU eviction with a
MAX_SEZM_CACHE_SIZE (use collections.OrderedDict to popitem(last=False) on
insertion when size exceeded). Apply this change to the module-level maps named
_SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER and ensure any insertion paths
(e.g., in get_sezm_model(), wherever compiled graphs are cached after
share_params) use the new bounded cache APIs so old compiled graphs/wrappers can
be released.
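
Option (b) from the comment, an LRU-bounded cache built on `collections.OrderedDict`, might look roughly like this. The class name and its methods are hypothetical, and the size limit is passed in rather than read from a module constant.

```python
from collections import OrderedDict
from typing import Any, Hashable


class BoundedCache:
    """Small LRU cache: the least recently used entry is evicted first."""

    def __init__(self, max_size: int) -> None:
        self._data: OrderedDict[Hashable, Any] = OrderedDict()
        self._max_size = max_size

    def get(self, key: Hashable) -> Any:
        if key not in self._data:
            return None
        # A hit marks the entry as most recently used.
        self._data.move_to_end(key)
        return self._data[key]

    def put(self, key: Hashable, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        # Evict from the oldest end until the size bound holds again.
        while len(self._data) > self._max_size:
            self._data.popitem(last=False)

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __len__(self) -> int:
        return len(self._data)
```

Eviction only drops the cache's reference; a compiled callable that a live model still holds in its per-instance cache stays usable, while callables of discarded models become collectable.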
🧹 Nitpick comments (1)
source/tests/pt/model/test_sezm_model.py (1)

616-620: ⚡ Quick win

Add a zero/one-edge regression next to this new tail check.

The dummy_count = 2 change is justified by E == 0/1 symbolic-trace failures, but this fixture still exercises a normal dense frame. A single-atom or single-real-edge case would actually lock in the scenario this change is meant to protect.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_sezm_model.py` around lines 616 - 620, Add
explicit tests for the E==0 and E==1 edge-count cases alongside the existing
tail check: create or parametrize the fixture/state so cache_std.src (or the
structure used to derive n_real) is exercised with zero real edges and with a
single real edge, then assert the padded-tail behavior (edge_mask length, that
edge_mask[n_real:] are all False, and that dummy_count remains 2) for both
cases; update the test around n_real/edge_mask (symbols: cache_std.src, n_real,
edge_mask, dummy_count) to run these two scenarios so the regression for
single-atom/single-edge symbolic traces is covered.
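
One way to shape the requested regression is a `subTest` loop over small edge counts. The sketch below is self-contained and therefore uses a hypothetical `padded_edge_mask` helper in place of the real fixture and `cache_std`; only the assertions mirror what the comment asks for.

```python
import unittest


def padded_edge_mask(n_real: int, dummy_count: int = 2) -> list[bool]:
    """Hypothetical stand-in for the model's edge mask with a dummy tail."""
    return [True] * n_real + [False] * dummy_count


class TestDummyEdgeTail(unittest.TestCase):
    def test_tail_is_masked_for_small_edge_counts(self) -> None:
        dummy_count = 2
        # 0 and 1 are the cases named in the review; 5 is a normal frame.
        for n_real in (0, 1, 5):
            with self.subTest(n_real=n_real):
                mask = padded_edge_mask(n_real, dummy_count)
                self.assertEqual(len(mask), n_real + dummy_count)
                self.assertTrue(all(mask[:n_real]))
                self.assertFalse(any(mask[n_real:]))
                # The padded edge count never falls to 0 or 1.
                self.assertGreaterEqual(len(mask), 2)
```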
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 455-463: _SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER currently
grow forever; change their implementation to bound lifetime by either (a) using
weak references so entries are removed when the compiled artifacts are
garbage-collected (e.g., store values in weakref.WeakValueDictionary or store
weakrefs to compiled callables) and/or (b) implementing an explicit LRU eviction
with a MAX_SEZM_CACHE_SIZE (use collections.OrderedDict to popitem(last=False)
on insertion when size exceeded). Apply this change to the module-level maps
named _SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER and ensure any insertion
paths (e.g., in get_sezm_model(), wherever compiled graphs are cached after
share_params) use the new bounded cache APIs so old compiled graphs/wrappers can
be released.

---

Nitpick comments:
In `@source/tests/pt/model/test_sezm_model.py`:
- Around line 616-620: Add explicit tests for the E==0 and E==1 edge-count cases
alongside the existing tail check: create or parametrize the fixture/state so
cache_std.src (or the structure used to derive n_real) is exercised with zero
real edges and with a single real edge, then assert the padded-tail behavior
(edge_mask length, that edge_mask[n_real:] are all False, and that dummy_count
remains 2) for both cases; update the test around n_real/edge_mask (symbols:
cache_std.src, n_real, edge_mask, dummy_count) to run these two scenarios so the
regression for single-atom/single-edge symbolic traces is covered.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: aabc757f-ea39-4bdd-9408-479453e55637

📥 Commits

Reviewing files that changed from the base of the PR and between 52bcebc and b23ce5e.

📒 Files selected for processing (2)
  • deepmd/pt/model/model/sezm_model.py
  • source/tests/pt/model/test_sezm_model.py

@anyangml anyangml enabled auto-merge June 5, 2026 02:17
@iProzd iProzd added and then removed the Test CUDA (Trigger test CUDA workflow) label Jun 5, 2026
@anyangml anyangml disabled auto-merge June 5, 2026 02:19
@anyangml anyangml enabled auto-merge June 5, 2026 04:59
@anyangml anyangml added this pull request to the merge queue Jun 5, 2026
Merged via the queue into deepmodeling:master with commit 87d6e4e Jun 6, 2026
73 checks passed
@anyangml anyangml deleted the fix/dpa4-multitask-compile branch June 6, 2026 00:36
@coderabbitai coderabbitai Bot mentioned this pull request Jun 28, 2026

6 participants