[TRTLLM-12128][feat] enable SageAttention for Wan/FLUX (new commits) #13570
xrq-phys wants to merge 12 commits into NVIDIA:main from
Conversation
/bot run
PR_Github #45971 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request introduces SageAttention support to the visual generation pipeline. Changes include adding a new
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~35 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (1)
1-1: ⚠️ Potential issue | 🟠 Major: Update the copyright year in this modified file.
The header still says 2025, but this file now has 2026 changes.
As per coding guidelines, "Include NVIDIA copyright header on all new files; update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` at line 1, the file still contains a 2025 NVIDIA copyright header; update the header year to 2026 in the top-of-file comment (the SPDX/copyright line) so the file reflects the current modification year.
tests/unittest/_torch/visual_gen/test_attention_perf.py (1)
1-1: ⚠️ Potential issue | 🟠 Major: Update the copyright year in this modified file.
The header still says 2025, but this file now has 2026 changes.
As per coding guidelines, "Include NVIDIA copyright header on all new files; update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` at line 1, Update the SPDX copyright header year from 2025 to 2026 in the file by editing the top-of-file header line that currently reads "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved." to reflect 2026 so the modified file's header matches the current changes.
🧹 Nitpick comments (1)
tests/unittest/_torch/visual_gen/test_attention_integration.py (1)
693-701: QA integration test-list updates are not needed for this change. This PR touches tests/unittest/... coverage; no tests/integration/defs/... additions were made, so QA scheduled-list changes are unnecessary here. As per coding guidelines: "If the PR only touches unittest/ or narrow unit scope, say explicitly whether QA list updates are unnecessary or optional."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_attention_integration.py` around lines 693 - 701, This change only updates unit tests under tests/unittest/_torch/visual_gen/test_attention_integration.py (the loop invoking test_sage_attention_self_attention with results[label]) and does not add any integration defs, so explicitly state that QA integration list updates are unnecessary by adding a clear one-line comment or PR note near the test block (referencing test_sage_attention_self_attention and the results dict/label loop) that says "QA integration test-list updates are unnecessary for this PR" to satisfy the coding guideline.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/visual_gen/README.md`:
- Around line 91-100: The README currently shows a SageAttention usage example
for WAN (visual_gen_wan_t2v.py with --enable_sage_attention and
--attention_backend TRTLLM) but the PR also adds FLUX support; update the docs
to either (A) add a FLUX-specific SageAttention example/row mirroring the WAN
example (use the FLUX model name, appropriate script or --attention_backend
value for FLUX and include --enable_sage_attention), or (B) explicitly state
that SageAttention is WAN-only in the example and the support table; modify the
example block and the support table entries consistently so they reflect the
chosen scope.
In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py`:
- Line 241: The API exposes seq_len_kv but the non-Sage path in
trtllm.py/_concat_qkv (and surrounding logic where seq_len_kv is accepted) only
supports self-attention; add a guard that if seq_len_kv is not None and
seq_len_kv != seq_len you raise a clear ValueError (or NotImplementedError)
explaining cross-attention is unsupported yet, referencing the parameter name
seq_len_kv and the helper _concat_qkv so callers (e.g., attention module) get a
deterministic failure instead of an obscure reshape error; apply the same check
in the other related block(s) around lines 268-295 where seq_len_kv is used.
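A minimal sketch of that guard; the helper name below is hypothetical, while seq_len, seq_len_kv, and _concat_qkv are the names used in the comment above:

```python
from typing import Optional

def _validate_self_attention_seq_lens(seq_len: int, seq_len_kv: Optional[int]) -> None:
    """Fail fast when a caller requests cross-attention on the non-Sage path."""
    if seq_len_kv is not None and seq_len_kv != seq_len:
        raise NotImplementedError(
            f"seq_len_kv={seq_len_kv} != seq_len={seq_len}: cross-attention is "
            "not supported by this path yet (see _concat_qkv)."
        )
```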
- Around line 271-274: Replace the inline assert in the Sage attention input
check with an explicit input-validation exception: in the method where
self.sage_attention_config is checked (the block that currently does "assert k
is not None and v is not None, ('SageAttention requires separate Q, K, V
tensors')"), change it to explicitly test "if k is None or v is None:" and raise
a ValueError("SageAttention requires separate Q, K, V tensors") so the invalid
inputs fail deterministically at this point (before any reshape or downstream
ops).
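A sketch of the suggested replacement, wrapped in a hypothetical helper so it is self-contained; the message is taken from the original assert:

```python
from typing import Optional

import torch

def _require_separate_qkv(k: Optional[torch.Tensor], v: Optional[torch.Tensor]) -> None:
    # Explicit validation still runs under `python -O`, unlike an assert.
    if k is None or v is None:
        raise ValueError("SageAttention requires separate Q, K, V tensors")
```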
In `@tensorrt_llm/_torch/visual_gen/attention_backend/utils.py`:
- Around line 114-119: The code currently forwards
attention_config.sage_attention_config into kwargs before checking backend,
which can pass TRTLLM-only args to other backends; modify the logic so that
sage_attention_config is only added to kwargs when backend.upper() == "TRTLLM"
and attention_config is not None and attention_config.sage_attention_config is
not None (i.e., move or wrap the existing kwargs["sage_attention_config"]
assignment behind the backend check), referencing
attention_config.sage_attention_config, kwargs, and backend to locate and fix
the spot.
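A sketch of the gated forwarding, again as a hypothetical helper; backend, attention_config, and kwargs follow the names in the comment:

```python
def _maybe_forward_sage_config(backend: str, attention_config, kwargs: dict) -> None:
    """Forward sage_attention_config only when the TRTLLM backend will consume it."""
    if (
        backend.upper() == "TRTLLM"
        and attention_config is not None
        and attention_config.sage_attention_config is not None
    ):
        kwargs["sage_attention_config"] = attention_config.sage_attention_config
```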
In `@tensorrt_llm/_torch/visual_gen/config.py`:
- Around line 121-124: The error text in the raise ValueError uses misleading
wording that claims a fallback ("Fallback to non-SageAttention TRTLLM
attention") even though the code actually raises an exception; update the
message in the raise ValueError in the block referencing
self.sage_attention_config to accurately state that the configuration is
unsupported and that no fallback will occur (e.g., "Unsupported
{self.sage_attention_config}; cannot fallback to non-SageAttention"), and ensure
the message includes the self.sage_attention_config value for clarity.
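A drop-in sketch of the reworded raise, to be placed in the block that checks self.sage_attention_config:

```python
raise ValueError(
    f"Unsupported sage_attention_config: {self.sage_attention_config!r}; "
    "cannot fall back to non-SageAttention TRTLLM attention."
)
```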
- Around line 54-80: The SageAttentionConfig class currently inherits from
BaseModel which permits extra fields; change it to inherit from StrictBaseModel
to enforce extra='forbid' and fail fast on unknown fields: update the class
definition for SageAttentionConfig to extend StrictBaseModel (and add/import
StrictBaseModel where necessary), keeping all existing PydanticField attributes
(num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v, qk_int8) unchanged.
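A sketch of the strict-model pattern; the repo is said to already provide StrictBaseModel, so the stand-in below only shows the equivalent effect, and the default values are placeholders based on the (1,16,1) preset mentioned elsewhere in this thread:

```python
from pydantic import BaseModel, ConfigDict

class StrictBaseModel(BaseModel):
    """Stand-in for the repo's StrictBaseModel: unknown fields fail fast."""
    model_config = ConfigDict(extra="forbid")

class SageAttentionConfig(StrictBaseModel):
    # Field names from the review comment; defaults are illustrative.
    num_elts_per_blk_q: int = 1
    num_elts_per_blk_k: int = 16
    num_elts_per_blk_v: int = 1
    qk_int8: bool = True

# With extra="forbid", a typo now raises a ValidationError instead of being
# silently accepted:
# SageAttentionConfig(num_elts_per_blk_qq=1)  # -> pydantic.ValidationError
```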
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 311-327: The code incorrectly computes seq_len and seq_len_kv
using q.shape[1], which is the number of heads for HND layout; update the
computation to derive sequence length from the correct dimension after reshape
(use q.shape[2] for HND layout) and likewise compute seq_len_kv from k.shape[2]
when k is not None; then pass those corrected seq_len and seq_len_kv values in
the kwargs to self.attn.forward (adjust logic around seq_len = q.shape[...] and
seq_len_kv = k.shape[...] to handle both HND and other layouts consistently).
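A drop-in sketch of the corrected indexing, assuming HND layout where tensors are (batch, num_heads, seq_len, head_dim); the k-is-None fallback is an assumption:

```python
# Dim 1 is the head count in HND layout; the sequence length lives in dim 2.
seq_len = q.shape[2]
# Assumed fallback: self-attention length when K is fused into the QKV tensor.
seq_len_kv = k.shape[2] if k is not None else seq_len
```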
In `@tests/unittest/_torch/visual_gen/test_attention_integration.py`:
- Around line 291-387: The test test_sage_attention_self_attention parametrizes
seq_len with very large values (e.g., 16384, 32760) causing CI timeouts/OOM;
change the parametrization to a small "fast" set for normal unit runs (e.g.,
[256, 512, 1024]) and move the heavy shapes into a separate gated marker (e.g.,
pytest.mark.slow or a new param set gated by pytest.config option) so the heavy
runs are only executed when requested; update the decorator around
test_sage_attention_self_attention and reference the same
SageAttentionConfig/Attention/NaiveWanSelfAttention flow so the logic and
weight-copying (copy_weights_self_attention) remain unchanged.
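A sketch of the split parametrization using pytest.mark.slow as the gate; this assumes a "slow" marker is registered in the pytest config, and the fast/heavy values follow the comment:

```python
import pytest

FAST_SEQ_LENS = [256, 512, 1024]
HEAVY_SEQ_LENS = [16384, 32760]  # only run when slow tests are requested

@pytest.mark.parametrize(
    "seq_len",
    FAST_SEQ_LENS
    + [pytest.param(s, marks=pytest.mark.slow) for s in HEAVY_SEQ_LENS],
)
def test_sage_attention_self_attention(seq_len):
    ...  # existing SageAttentionConfig / NaiveWanSelfAttention flow unchanged
```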
In `@tests/unittest/_torch/visual_gen/test_attention_perf.py`:
- Around line 947-1038: The tests only assert timing results but not that the
SageAttention execution path ran; update the production benchmarking path to
emit an observable flag (e.g., add a field like "used_sage" or "exec_path" to
the dict returned by _bench and benchmark.benchmark_single when the
SageAttention code path is taken) and then assert that flag in
test_sage_runs_and_times, test_sage_vs_vanilla_quick, and
test_sage_vs_trtllm_wan_shapes (use the existing result variables sage, vanilla,
trtllm and the helper _bench/benchmark_single symbols to locate changes); ensure
the runtime sets that flag inside the SageAttention dispatch/implementation code
so the tests fail if code silently falls back to TRTLLM.
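A sketch of the observable-flag idea; _bench, exec_path, and last_exec_path are hypothetical names following the comment:

```python
import time

def _bench(attn_module, run_fn) -> dict:
    """Times run_fn and records which attention path actually executed."""
    t0 = time.perf_counter()
    run_fn()
    elapsed_ms = (time.perf_counter() - t0) * 1e3
    return {
        "mean_ms": elapsed_ms,
        # Assumed attribute set by the SageAttention dispatch; lets tests
        # detect a silent fallback to the vanilla TRTLLM path.
        "exec_path": getattr(attn_module, "last_exec_path", "unknown"),
    }

# Usage in the tests (names follow the comment above):
# sage = _bench(sage_attn, lambda: sage_attn(q, k, v))
# assert sage["exec_path"] == "sage", "SageAttention path did not run"
```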
- Around line 897-910: Change the mutable lists WAN_SEQ_LENS and QUICK_SEQ_LENS
to immutable tuples (replace [...] with (...)) and update the three zip calls
that pair _SAGE_CONFIGS with _SAGE_CONFIG_IDS to use zip(..., strict=True) so
mismatched lengths raise errors; specifically modify the zip invocations that
iterate over _SAGE_CONFIGS and _SAGE_CONFIG_IDS (the three places where those
two names are zipped) to pass strict=True.
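A sketch of both hardening changes; the values below are illustrative placeholders, not the real WAN shapes or Sage configs:

```python
# Immutable module-level constants instead of mutable lists.
WAN_SEQ_LENS = (16384, 32760)
QUICK_SEQ_LENS = (256, 1024)

_SAGE_CONFIGS = ({"qk_int8": True}, {"qk_int8": False})  # placeholders
_SAGE_CONFIG_IDS = ("int8", "fp16")                      # placeholders

# strict=True (Python 3.10+) raises ValueError on a length mismatch instead
# of silently truncating to the shorter sequence.
for cfg, cfg_id in zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS, strict=True):
    print(cfg_id, cfg)
```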
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 2e0d8f56-e8da-4005-8f3a-f81b82bfd310
📒 Files selected for processing (15)
- cpp/tensorrt_llm/thop/attentionOp.cpp
- examples/visual_gen/README.md
- examples/visual_gen/visual_gen_flux.py
- examples/visual_gen/visual_gen_ltx2.py
- examples/visual_gen/visual_gen_wan_i2v.py
- examples/visual_gen/visual_gen_wan_t2v.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
- tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
- tensorrt_llm/_torch/visual_gen/config.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py
- tests/unittest/_torch/visual_gen/test_attention_integration.py
- tests/unittest/_torch/visual_gen/test_attention_perf.py
PR_Github #45971 [ run ] completed with state
83d8984 to b7be394 (force-pushed)
/bot run
PR_Github #46052 [ run ] triggered by Bot. Commit:
PR_Github #46052 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #46095 [ run ] triggered by Bot. Commit:
PR_Github #46095 [ run ] completed with state
/bot run
PR_Github #46180 [ run ] triggered by Bot. Commit:
b7be394 to 276a366 (force-pushed)
I rebased to pick up #13598, which is needed to fix the debug build failures.
/bot run
def _sage_attention_config_for_model(model_path: str) -> tuple[dict, str]:
    """INT8 Q/K Sage preset: (1,4,1) for Wan2.1-I2V, else (1,16,1)."""
This is interesting – it seems I2V is more sensitive to attn accuracy. Is that right?
Tbh I haven't actually compared I2V against T2V.
Overall, smaller models tend to pose higher precision requirements. From my perspective, the special handling should detect Wan2.1 1.3B instead of Wan2.1 I2V.
Perhaps @o-stoner implemented this only for I2V based on actual tests. I'll have to verify with real runs later.
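For reference, a hypothetical reconstruction of the helper from the quoted signature and docstring; the dict keys reuse the SageAttentionConfig field names from the review, and the model-detection string is an assumption:

```python
def _sage_attention_config_for_model(model_path: str) -> tuple[dict, str]:
    """INT8 Q/K Sage preset: (1,4,1) for Wan2.1-I2V, else (1,16,1)."""
    if "Wan2.1-I2V" in model_path:  # assumed detection heuristic
        return {
            "num_elts_per_blk_q": 1,
            "num_elts_per_blk_k": 4,
            "num_elts_per_blk_v": 1,
            "qk_int8": True,
        }, "Wan2.1-I2V"
    return {
        "num_elts_per_blk_q": 1,
        "num_elts_per_blk_k": 16,
        "num_elts_per_blk_v": 1,
        "qk_int8": True,
    }, "default"
```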
| "backend": args.attention_backend, | ||
| } | ||
| if args.enable_sage_attention: | ||
| sage_cfg, sage_preset = _sage_attention_config_for_model(args.model_path) |
There was a problem hiding this comment.
What does sage_preset stand for? Is sage_cfg alone not enough?
It seems every model has a recommended sage config. Can we relocate this into the model's default settings instead of scattering it across the examples?
These are preset names (Wan2.1-I2V or default).
Looks like AI code style again.
Sorry, I wanted to first check whether all features were functioning and haven't carefully reviewed these lines yet. 🙇
PR_Github #46196 [ run ] triggered by Bot. Commit:
PR_Github #46180 [ run ] completed with state
PR_Github #46196 [ run ] completed with state
/bot run
PR_Github #46221 [ run ] triggered by Bot. Commit:
PR_Github #46221 [ run ] completed with state
/bot run
ed2a1b6 to 90509d3 (force-pushed)
/bot run
PR_Github #46422 [ run ] triggered by Bot. Commit:
PR_Github #46422 [ run ] completed with state
90509d3 to 1f031bc (force-pushed)
/bot run --add-multi-gpu-test
PR_Github #46466 [ run ] triggered by Bot. Commit:
PR_Github #46466 [ run ] completed with state
/bot run --add-multi-gpu-test
PR_Github #46474 [ run ] triggered by Bot. Commit:
PR_Github #46474 [ run ] completed with state
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
1f031bc to ca40211 (force-pushed)
/bot run --add-multi-gpu-test
PR_Github #46488 [ run ] triggered by Bot. Commit:
Summary by CodeRabbit
New Features
- --enable_sage_attention flag (requires TRTLLM attention backend).
Documentation
Tests
Description
Copied from #13425 with new commits to fix CI failures.
Integrate SageAttention kernels from #12937 into VisualGen for Wan and FLUX.2 models.
LTX-2 SageAttention is not yet supported.
Preliminary perf / quality compared to #12548 baseline from example scripts:
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.