fix: dpa4 compile aliasing collision #5483
📝 Walkthrough
SeZMModel now shares compiled ener callables across tasks with compatible module structure by promoting per-task buffers to FX placeholders and passing them as runtime varargs. Tracing uses prime-based safe trace sizes and pads/clamps inputs. Edge-list indexing and docs are updated for symbolic tracing stability. Tests are adjusted to expect callable reuse.
Changes
- Multi-task compile cache with shared graphs and export adjustments
- Edge-list index and documentation updates
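One possible reading of "prime-based safe trace sizes" is sketched below: each dynamic dimension gets its own distinct prime length at trace time, so no two dimensions share a size and none is 0 or 1, and real inputs are padded up to that length. This is an illustrative assumption about the idea, not the PR's code; `pick_trace_sizes` and `pad_to` are hypothetical helpers and the dimension names are made up.

```python
from itertools import count


def _is_prime(n: int) -> bool:
    """Trial-division primality test; adequate for small trace sizes."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def pick_trace_sizes(actual: dict[str, int], floor: int = 5) -> dict[str, int]:
    """Assign each dynamic dim a distinct prime >= max(actual size, floor).

    Distinct primes are never equal to each other and never 0 or 1, so a
    tracer that merges equal sizes or specialises on 0/1 sees no coincidence.
    (Hypothetical helper, not taken from the PR.)
    """
    used: set[int] = set()
    sizes: dict[str, int] = {}
    for name, n in actual.items():
        for cand in count(max(n, floor)):
            if _is_prime(cand) and cand not in used:
                used.add(cand)
                sizes[name] = cand
                break
    return sizes


def pad_to(values: list[int], length: int, fill: int = 0) -> list[int]:
    """Right-pad a 1-D list to the trace length; padding would be masked later."""
    return values + [fill] * (length - len(values))


# Two dims that both have real size 1 would otherwise look identical.
sizes = pick_trace_sizes({"nloc": 1, "nedge": 1, "nframes": 7})
padded = pad_to([3], sizes["nloc"])
```

Padding only changes the shapes seen at trace time; the padded tail still has to be masked or clamped so it does not affect results.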
🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
Pull request overview
This PR attempts to improve/repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompiles/OOM in multi-task setups and addressing symbolic-shape tracing issues in make_fx.
Changes:
- Add module-level compile sharing and promote selected per-task buffers (e.g., `out_bias`, `bias_atom_e`, `case_embd`) as FX inputs to enable compiled-graph reuse across shared-parameter tasks.
- Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable `ShapeEnv` duck sizing during tracing.
- Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5483      +/-   ##
==========================================
+ Coverage   81.34%   81.39%   +0.04%
==========================================
  Files         868      868
  Lines       96373    96772     +399
  Branches     4233     4240       +7
==========================================
+ Hits        78399    78771     +372
- Misses      16675    16698      +23
- Partials     1299     1303       +4
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)
2735-2744: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Clear the shared compile cache when the `ener` head is reset.
This branch clears only `self.compiled_core_compute_cache`, but Line 1721 can immediately repopulate it from `_SEZM_COMPILE_CACHE` when the structure key is unchanged. That bypasses the retrace promised by this method and can resurrect a callable traced against the pre-reset head.
Suggested change
 else:
+    structure_key = _sezm_structure_key(self)
+    stale_keys = [
+        key
+        for key in _SEZM_COMPILE_CACHE
+        if key[: len(structure_key)] == structure_key
+    ]
+    for key in stale_keys:
+        _SEZM_COMPILE_CACHE.pop(key, None)
+    _SEZM_TASK_BUF_ORDER.pop(structure_key, None)
     self._core_compute_pending_compile_t0 = None
     self._core_compute_pending_compile_key = None
     # Drop every compile slot so the next forward retraces against the
     # reinitialised fitting head.
     self.compiled_core_compute_cache.clear()
+    self._task_buf_order_cache.clear()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/sezm_model.py` around lines 2735 - 2744, The ener-head reset branch currently clears only instance cache (compiled_core_compute_cache) but leaves the shared cache (_SEZM_COMPILE_CACHE) intact so a subsequent lookup (see code that repopulates from _SEZM_COMPILE_CACHE using the structure key) can resurrect callables traced against the old head; update the else branch that resets the ener head to also invalidate the shared compile cache for the same structure key: when you set _core_compute_pending_compile_key to None and call compiled_core_compute_cache.clear(), also remove any entries in _SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the shared cache entirely) so retracing is forced (operate on the same key variable used to populate _SEZM_COMPILE_CACHE in the forward path).
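The failure mode in this comment can be reproduced with a small torch-free model of the two cache layers. The sketch assumes, as the comment describes, that shared-cache keys start with the structure key; `SHARED`, `Model`, and the string "callables" are hypothetical stand-ins, not the real `sezm_model.py` objects.

```python
# Hypothetical stand-in for the module-level cache.
SHARED: dict[tuple, str] = {}


class Model:
    def __init__(self, structure_key: tuple):
        self.structure_key = structure_key
        self.local: dict[tuple, str] = {}  # per-instance compile slots
        self.head_version = 0

    def get_compiled(self, shape_key: tuple) -> str:
        key = self.structure_key + shape_key
        if key not in self.local:
            # Consult the shared cache before tracing anew.
            if key not in SHARED:
                SHARED[key] = f"traced@head{self.head_version}"
            self.local[key] = SHARED[key]
        return self.local[key]

    def reset_head(self, purge_shared: bool) -> None:
        self.head_version += 1
        self.local.clear()
        if purge_shared:
            # Drop every shared entry whose key starts with our structure key.
            n = len(self.structure_key)
            for key in [k for k in SHARED if k[:n] == self.structure_key]:
                SHARED.pop(key, None)


m = Model(("desc", "fit"))
m.get_compiled(("ener",))
m.reset_head(purge_shared=False)
stale = m.get_compiled(("ener",))  # resurrected from SHARED
m.reset_head(purge_shared=True)
fresh = m.get_compiled(("ener",))  # retraced against the new head
```

Clearing only the instance slots returns the callable traced before the reset; purging by key prefix forces the retrace.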
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 465-483: The current _sezm_structure_key function risks returning
identical keys for different tasks because it only samples the first child of
descriptor and fitting_net; update it to either (preferred) derive the key from
all shared parameter objects (e.g., use a frozenset of ids for all parameters
that are shared between tasks) or (safer/alternative) detect whether the entire
descriptor and fitting_net stacks are fully shared and raise an error if not, so
partial sharing cannot collapse different tasks into the same
_SEZM_COMPILE_CACHE entry; locate the logic in _sezm_structure_key and use
SeZMModel.descriptor / SeZMModel.fitting_net plus the model.share_params
semantics to compute the full parameter-based key or to assert full-stack
sharing before returning a key for cache reuse.
- Line 1772: The loop using zip(task_buf_names, vals) can silently truncate if
the iterables differ; change it to zip(task_buf_names, vals, strict=True) to
make mismatched lengths raise an error. Locate the loop that iterates over
task_buf_names and vals (the line currently written as "for name, val in
zip(task_buf_names, vals):") and update it to include strict=True so any length
mismatch is detected immediately.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: b86f7a23-46a3-4ec5-b8a9-d4a97ba0f001
📒 Files selected for processing (2)
- deepmd/pt/model/model/sezm_model.py
- source/tests/pt/model/test_sezm_model.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)
455-463: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Bound the module-level compile cache lifetime.
`_SEZM_COMPILE_CACHE` and `_SEZM_TASK_BUF_ORDER` only grow, and nothing removes entries when models/wrappers are discarded. That turns compile sharing into a process-wide retention point for compiled artifacts, so repeated `get_sezm_model()` / multitask wrapper construction in one worker can accumulate old compiled graphs instead of releasing them.
Also applies to: 2115-2121
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/sezm_model.py` around lines 455 - 463, _SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER currently grow forever; change their implementation to bound lifetime by either (a) using weak references so entries are removed when the compiled artifacts are garbage-collected (e.g., store values in weakref.WeakValueDictionary or store weakrefs to compiled callables) and/or (b) implementing an explicit LRU eviction with a MAX_SEZM_CACHE_SIZE (use collections.OrderedDict to popitem(last=False) on insertion when size exceeded). Apply this change to the module-level maps named _SEZM_COMPILE_CACHE and _SEZM_TASK_BUF_ORDER and ensure any insertion paths (e.g., in get_sezm_model(), wherever compiled graphs are cached after share_params) use the new bounded cache APIs so old compiled graphs/wrappers can be released.
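Option (b) from this comment, LRU eviction on top of `collections.OrderedDict`, might look roughly like the sketch below. `BoundedCache` is a hypothetical class; `MAX_SEZM_CACHE_SIZE` is the name proposed in the comment, and the value 2 is chosen only to make eviction visible.

```python
from collections import OrderedDict
from typing import Callable, Hashable, Optional

MAX_SEZM_CACHE_SIZE = 2  # tiny value for demonstration only


class BoundedCache:
    """Least-recently-used cache with a fixed maximum number of entries."""

    def __init__(self, max_size: int):
        self._max = max_size
        self._data: OrderedDict[Hashable, Callable] = OrderedDict()

    def get(self, key: Hashable) -> Optional[Callable]:
        fn = self._data.get(key)
        if fn is not None:
            self._data.move_to_end(key)  # mark as most recently used
        return fn

    def put(self, key: Hashable, fn: Callable) -> None:
        self._data[key] = fn
        self._data.move_to_end(key)
        while len(self._data) > self._max:
            self._data.popitem(last=False)  # evict the least recently used

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __len__(self) -> int:
        return len(self._data)


cache = BoundedCache(MAX_SEZM_CACHE_SIZE)
cache.put("a", lambda: "A")
cache.put("b", lambda: "B")
cache.get("a")  # touch "a" so "b" becomes the oldest entry
cache.put("c", lambda: "C")  # evicts "b"
```

An evicted callable is only released once nothing else references it, so per-instance caches holding the same object would need to drop it as well.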
🧹 Nitpick comments (1)
source/tests/pt/model/test_sezm_model.py (1)
616-620: ⚡ Quick win
Add a zero/one-edge regression next to this new tail check.
The `dummy_count = 2` change is justified by `E == 0/1` symbolic-trace failures, but this fixture still exercises a normal dense frame. A single-atom or single-real-edge case would actually lock in the scenario this change is meant to protect.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/model/test_sezm_model.py` around lines 616 - 620, Add explicit tests for the E==0 and E==1 edge-count cases alongside the existing tail check: create or parametrize the fixture/state so cache_std.src (or the structure used to derive n_real) is exercised with zero real edges and with a single real edge, then assert the padded-tail behavior (edge_mask length, that edge_mask[n_real:] are all False, and that dummy_count remains 2) for both cases; update the test around n_real/edge_mask (symbols: cache_std.src, n_real, edge_mask, dummy_count) to run these two scenarios so the regression for single-atom/single-edge symbolic traces is covered.
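The shape of the requested regression can be sketched with plain lists. `pad_edges` is a hypothetical stand-in for the real edge-list construction (which works on tensors), and `dummy_count = 2` is taken from this comment. The point is only that the same tail assertions run for zero, one, and several real edges.

```python
def pad_edges(src: list[int], dummy_count: int = 2) -> tuple[list[int], list[bool]]:
    """Hypothetical helper: append `dummy_count` masked dummy edges."""
    n_real = len(src)
    padded = src + [0] * dummy_count
    mask = [True] * n_real + [False] * dummy_count
    return padded, mask


def check_tail(src: list[int], dummy_count: int = 2) -> int:
    """Assert the padded-tail invariants and return the padded length."""
    padded, mask = pad_edges(src, dummy_count)
    n_real = len(src)
    assert len(mask) == len(padded) == n_real + dummy_count
    assert all(mask[:n_real])  # real edges stay unmasked
    assert not any(mask[n_real:])  # the dummy tail is fully masked
    return len(padded)


# E == 0, E == 1, and a normal multi-edge case.
lengths = [check_tail(src) for src in ([], [0], [0, 1, 2])]
```

In a pytest suite the three inputs would more naturally become a `parametrize` list over the existing fixture.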
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: aabc757f-ea39-4bdd-9408-479453e55637
📒 Files selected for processing (2)
- deepmd/pt/model/model/sezm_model.py
- source/tests/pt/model/test_sezm_model.py
Summary by CodeRabbit
Bug Fixes
Performance
Changes
Tests