[TRTLLM-12128][feat] enable SageAttention for Wan/FLUX #13425
o-stoner wants to merge 5 commits into NVIDIA:main from
Conversation
Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
/bot run --disable-fail-fast

PR_Github #45422 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request adds SageAttention support to visual generation models by refactoring the attention configuration from individual parameters into a unified `AttentionConfig`.
Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant CLI as Example Script
    participant Config as Config System
    participant Factory as Attention Factory
    participant Backend as TrtllmAttention
    participant Ops as Low-level Ops
    CLI->>CLI: Parse --enable_sage_attention
    CLI->>Config: Build AttentionConfig with sage_attention_config
    Config->>Config: Validate backend=TRTLLM & supported block sizes
    Config-->>Factory: Return AttentionConfig
    Factory->>Factory: Extract sage_attention_config
    Factory->>Backend: Create with sage_attention_config
    Backend->>Backend: Store sage_attention_config
    alt SageAttention Enabled
        Backend->>Backend: Reshape Q/K/V to flattened forms
        Backend->>Ops: Call forward with per-block quantization
        Ops-->>Backend: Return attention output
    else Standard Path
        Backend->>Backend: Fuse/concatenate Q/K/V
        Backend->>Ops: Call forward without SageAttention params
        Ops-->>Backend: Return attention output
    end
    Backend-->>CLI: Attention result
```
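The "per-block quantization" step in the diagram can be illustrated with a small standalone sketch. This is pure Python with hypothetical function names, not the actual CUDA kernels: each block of `num_elts_per_blk` values gets its own scale, so an outlier in one block does not degrade precision elsewhere.

```python
# Hypothetical sketch of per-block INT8 quantization, the idea behind the
# num_elts_per_blk_q/k/v parameters: quantize each block with its own scale.
def quantize_per_block(values, num_elts_per_blk):
    """Return (int8_values, per_block_scales)."""
    quantized, scales = [], []
    for start in range(0, len(values), num_elts_per_blk):
        block = values[start:start + num_elts_per_blk]
        amax = max(abs(v) for v in block) or 1.0  # avoid division by zero
        scale = amax / 127.0
        scales.append(scale)
        quantized.extend(round(v / scale) for v in block)
    return quantized, scales

def dequantize_per_block(quantized, scales, num_elts_per_blk):
    # Each element is rescaled by the scale of the block it belongs to.
    return [q * scales[i // num_elts_per_blk] for i, q in enumerate(quantized)]

# A large value in the second block does not hurt the first block's precision.
q, s = quantize_per_block([0.5, -1.0, 100.0, 2.0], num_elts_per_blk=2)
deq = dequantize_per_block(q, s, num_elts_per_blk=2)
```

The real kernels apply this per block along the head dimension of Q/K/V tensors; the sketch only shows why smaller blocks localize quantization error.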
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py (1)
21-23: ⚠️ Potential issue | 🟡 Minor: Fix the example pytest target.
The documented command points at test_sage_ulysses_attention.py, but this file is named test_ulysses_sage_attention.py, so copy/paste will fail.

🛠 Suggested fix

```diff
- pytest tests/unittest/_torch/visual_gen/multi_gpu/test_sage_ulysses_attention.py -v
+ pytest tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py -v
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py` around lines 21 - 23, update the example pytest command in the docstring so it references the correct test filename: replace the incorrect `test_sage_ulysses_attention.py` target with `test_ulysses_sage_attention.py` (the actual test file name) so the example reads: pytest tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py -v and copy/paste runs the right test.

cpp/tensorrt_llm/thop/attentionOp.cpp (1)
652-675: ⚠️ Potential issue | 🔴 Critical: Remove the overly broad global gate.
The check is_mla_enable || is_fused_qkv || use_sage_attn conflates three independent concerns and blocks future callers from using unfused K/V without MLA or SageAttention enabled. The SageAttention-specific validation at lines 669-675 remains and ensures correctness for SageAttention paths.

🔧 Suggested fix

```diff
- TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn,
-     "Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention");
  TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/thop/attentionOp.cpp` around lines 652 - 675, remove the broad gate that currently enforces "is_mla_enable || is_fused_qkv || use_sage_attn" by deleting the TLLM_CHECK_WITH_INFO call that uses those three flags; instead rely on the existing, more specific validations already present (the later TLLM_CHECK_WITH_INFO calls that reference is_fused_qkv, k.has_value(), v.has_value(), update_kv_cache, and the SageAttention checks guarded by use_sage_attn). Ensure no other logic depends on that global check and keep the local checks for fused QKV (is_fused_qkv), KV-cache update (update_kv_cache), and SageAttention (use_sage_attn) intact so unfused K/V paths are allowed when appropriate.

tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (1)
268-295: ⚠️ Potential issue | 🟠 Major: Cross-attention with seq_len != seq_len_kv can fail in the non-Sage path.
When k/v are provided and the lengths differ, _concat_qkv() flattens to [B*seq_len, ...] and [B*kv_seq_len, ...]; torch.cat(..., dim=-1) then errors because the first dimensions do not match.

Suggested fix (handle unequal Q/KV lengths explicitly)

```diff
     else:
         if k is None and v is None:
             qkv = q.reshape(batch_size * seq_len, -1)
-        else:
+            output = super().forward(
+                q=qkv,
+                k=None,
+                v=None,
+                metadata=prepared_metadata,
+                attention_mask=attention_mask,
+            )
+        elif kv_seq_len != seq_len:
+            q_flat = q.reshape(batch_size * seq_len, -1)
+            k_flat = k.reshape(batch_size * kv_seq_len, -1)
+            v_flat = v.reshape(batch_size * kv_seq_len, -1)
+            output = super().forward(
+                q=q_flat,
+                k=k_flat,
+                v=v_flat,
+                metadata=prepared_metadata,
+                attention_mask=attention_mask,
+            )
+        else:
             qkv = self._concat_qkv(q, k, v, batch_size, seq_len, kv_seq_len)
-        output = super().forward(
-            q=qkv,
-            k=None,
-            v=None,
-            metadata=prepared_metadata,
-            attention_mask=attention_mask,
-        )
+            output = super().forward(
+                q=qkv,
+                k=None,
+                v=None,
+                metadata=prepared_metadata,
+                attention_mask=attention_mask,
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` around lines 268 - 295, The non-Sage branch fails when seq_len_kv != seq_len because _concat_qkv flattens Q and K/V to different first-dim sizes (batch_size*seq_len vs batch_size*kv_seq_len) and torch.cat errors; fix by detecting when k and v are provided and kv_seq_len != seq_len and, instead of calling _concat_qkv, reshape q to (batch_size*seq_len, -1) and k, v to (batch_size*kv_seq_len, -1) individually and pass q, k, v separately into super().forward (same pattern used in the Sage branch) so the backend receives correctly-shaped Q/K/V; update the forward implementation (and if needed _concat_qkv) to only use _concat_qkv when seq_len == seq_len_kv.
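To make the failure mode concrete, here is a toy shape calculation. It is illustrative only; the tensor sizes are hypothetical and no actual backend code is involved.

```python
# Toy illustration of why concatenating flattened Q and K/V fails when the
# query and key/value sequence lengths differ: the flattened first dimensions
# no longer match, so a concat along the last dimension is impossible.
batch_size, seq_len, kv_seq_len = 2, 64, 77  # hypothetical cross-attention sizes

q_rows = batch_size * seq_len      # first dim of flattened Q: 128
kv_rows = batch_size * kv_seq_len  # first dim of flattened K and V: 154

# torch.cat([q, k, v], dim=-1) requires identical leading dimensions:
can_concat = q_rows == kv_rows
print(can_concat)  # False

# The fix in the review: keep Q and K/V flattened separately and pass them
# individually to the backend instead of fusing them into one QKV tensor.
```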
🧹 Nitpick comments (2)
tests/unittest/_torch/visual_gen/test_attention_perf.py (1)
897-903: Prefer immutable class-level test matrices.
WAN_SEQ_LENS and QUICK_SEQ_LENS are mutable class attributes; accidental mutation can leak across tests.

Suggested diff

```diff
-    WAN_SEQ_LENS = [
+    WAN_SEQ_LENS = (
         (1, 12, 14040, 128, "wan_1.3b_480p_33f"),
         (1, 12, 32760, 128, "wan_1.3b_480p_81f"),
         (1, 40, 14040, 128, "wan_14b_480p_33f"),
         (1, 40, 32760, 128, "wan_14b_480p_81f"),
         (1, 40, 75600, 128, "wan_14b_720p_81f"),
-    ]
+    )
     ...
-    QUICK_SEQ_LENS = [
+    QUICK_SEQ_LENS = (
         (1, 12, 1024, 128, "quick_1k"),
         (1, 12, 4096, 128, "quick_4k"),
         (2, 12, 1024, 128, "quick_batch2"),
-    ]
+    )
```

Also applies to: 906-910
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` around lines 897 - 903, WAN_SEQ_LENS and QUICK_SEQ_LENS are defined as mutable lists at class scope which can be accidentally mutated across tests; change both to immutable sequence types (e.g., tuples or nested tuples) and keep them as class-level constants (retain names) so test matrices cannot be modified at runtime, updating any code that iterates them to work with tuples if necessary (refer to WAN_SEQ_LENS and QUICK_SEQ_LENS in the test class).

tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (1)
271-274: Use an explicit exception instead of assert for runtime validation.
assert can be stripped under optimized Python execution; this check should be a guaranteed runtime guard.

Suggested diff

```diff
-        assert k is not None and v is not None, (
-            "SageAttention requires separate Q, K, V tensors"
-        )
+        if k is None or v is None:
+            raise ValueError("SageAttention requires separate Q, K, V tensors")
```

As per coding guidelines: "Use built-in Python exception types; use exceptions for error handling, not return values."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` around lines 271 - 274, Replace the runtime assert in the block that checks self.sage_attention_config with an explicit exception raise so the validation always runs; specifically, in the method containing the lines checking self.sage_attention_config, replace "assert k is not None and v is not None" with a raised built-in exception (e.g., raise ValueError) that includes a clear message like "SageAttention requires separate Q, K, V tensors" and reference k and v being None as needed to aid debugging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/config.py`:
- Around line 71-80: The descriptions claim "0 disables" for
num_elts_per_blk_q/k/v but _validate_sage_attn_config only allows whitelisted
SageAttention presets; update the validator to implement the disabled path:
inside _validate_sage_attn_config detect if any of num_elts_per_blk_q,
num_elts_per_blk_k, or num_elts_per_blk_v == 0 and treat that as an explicit
"disable quantization"/custom-config mode (skip the preset whitelist check and
ensure qk_int8 handling remains correct), otherwise continue enforcing the
existing preset whitelist; reference the fields num_elts_per_blk_q,
num_elts_per_blk_k, num_elts_per_blk_v and the validator function
_validate_sage_attn_config when making the change.
- Around line 54-96: SageAttentionConfig currently inherits from BaseModel
causing unknown fields to be ignored; change its base to StrictBaseModel so
nested user-facing Pydantic validation fails fast. Locate the class definition
for SageAttentionConfig and replace BaseModel with StrictBaseModel (keeping
existing field names like num_elts_per_blk_q, num_elts_per_blk_k,
num_elts_per_blk_v, qk_int8 and the PydanticField usages) so
AttentionConfig.sage_attention_config will enforce strict validation.
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 311-320: The seq_len and seq_len_kv values are taken from
q.shape[1]/k.shape[1] after the HND reshape, which yields head dimensions (so
VANILLA/FA4 backends get head counts instead of real sequence lengths); update
the logic in the attention module where seq_len and seq_len_kv are set (before
calling self.attn.forward) to compute lengths from the pre-reshaped tensors or
branch on backend_layout (e.g., if backend_layout indicates HND, derive
seq_len/seq_len_kv from the original un-reshaped sequence dimension or use
stored sequence_length variables, otherwise keep the current q.shape/k.shape
logic) and pass those corrected seq_len and seq_len_kv into self.attn.forward to
ensure correct sequence length semantics for HND backends.
In `@tests/unittest/_torch/visual_gen/test_attention_integration.py`:
- Around line 289-387: The test only exercises num_elts_per_blk_k values 1 and
16 via the qk_int8 boolean in test_sage_attention_self_attention, so add
coverage for the Wan2.1 preset (k=4) by changing the parametrization: introduce
a new parameter (e.g., num_elts_per_blk_k_values) or replace qk_int8 with a
param that enumerates [1,4,16], then construct
SageAttentionConfig(num_elts_per_blk_k=that_value, num_elts_per_blk_q=1,
num_elts_per_blk_v=1, qk_int8=(that_value!=1)) inside the test; keep the test
function name test_sage_attention_self_attention and SageAttentionConfig usage
so the new k=4 path is exercised.
In `@tests/unittest/_torch/visual_gen/test_attention_perf.py`:
- Line 942: The parametrized tests use zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS)
which will silently drop cases if the two lists differ in length; update the
pytest.mark.parametrize calls to use zip(..., strict=True) for each occurrence
(e.g. where pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS,
_SAGE_CONFIG_IDS)) at lines shown and the other occurrences referencing
_SAGE_CONFIGS and _SAGE_CONFIG_IDS) so that a length mismatch raises
immediately; modify each zip invocation (including the ones at the other noted
locations) to pass strict=True.
- Around line 969-999: Add a loose performance assertion to prevent silent
regressions: in test_sage_vs_vanilla_quick (and similarly in
test_sage_vs_trtllm_wan_shapes) after computing speedup_vs_vanilla and
speedup_vs_trtllm, add an assertion such as assert speedup_vs_vanilla > 0.9,
f"slowdown vs VANILLA: {speedup_vs_vanilla:.2f}x" (and/or assert
speedup_vs_trtllm > 0.9 with a similar message) so the test fails when
SageAttention is materially slower; update the assertion threshold as needed for
CI stability and keep the existing print statements intact.
---
Outside diff comments:
In `@cpp/tensorrt_llm/thop/attentionOp.cpp`:
- Around line 652-675: Remove the broad gate that currently enforces
"is_mla_enable || is_fused_qkv || use_sage_attn" by deleting the
TLLM_CHECK_WITH_INFO call that uses those three flags; instead rely on the
existing, more specific validations already present (the later
TLLM_CHECK_WITH_INFO calls that reference is_fused_qkv, k.has_value(),
v.has_value(), update_kv_cache, and the SageAttention checks guarded by
use_sage_attn). Ensure no other logic depends on that global check and keep the
local checks for fused QKV (is_fused_qkv), KV-cache update (update_kv_cache),
and SageAttention (use_sage_attn) intact so unfused K/V paths are allowed when
appropriate.
In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py`:
- Around line 268-295: The non-Sage branch fails when seq_len_kv != seq_len
because _concat_qkv flattens Q and K/V to different first-dim sizes
(batch_size*seq_len vs batch_size*kv_seq_len) and torch.cat errors; fix by
detecting when k and v are provided and kv_seq_len != seq_len and, instead of
calling _concat_qkv, reshape q to (batch_size*seq_len, -1) and k, v to
(batch_size*kv_seq_len, -1) individually and pass q, k, v separately into
super().forward (same pattern used in the Sage branch) so the backend receives
correctly-shaped Q/K/V; update the forward implementation (and if needed
_concat_qkv) to only use _concat_qkv when seq_len == seq_len_kv.
In `@tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py`:
- Around line 21-23: Update the example pytest command in the docstring so it
references the correct test filename: replace the incorrect
`test_sage_ulysses_attention.py` target with `test_ulysses_sage_attention.py`
(the actual test file name) in the run example shown in the file
`tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py`;
ensure the example reads: pytest
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py -v so
copy/paste runs the right test.
---
Nitpick comments:
In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py`:
- Around line 271-274: Replace the runtime assert in the block that checks
self.sage_attention_config with an explicit exception raise so the validation
always runs; specifically, in the method containing the lines checking
self.sage_attention_config, replace "assert k is not None and v is not None"
with a raised built-in exception (e.g., raise ValueError) that includes a clear
message like "SageAttention requires separate Q, K, V tensors" and reference k
and v being None as needed to aid debugging.
In `@tests/unittest/_torch/visual_gen/test_attention_perf.py`:
- Around line 897-903: WAN_SEQ_LENS and QUICK_SEQ_LENS are defined as mutable
lists at class scope which can be accidentally mutated across tests; change both
to immutable sequence types (e.g., tuples or nested tuples) and keep them as
class-level constants (retain names) so test matrices cannot be modified at
runtime, updating any code that iterates them to work with tuples if necessary
(refer to WAN_SEQ_LENS and QUICK_SEQ_LENS in the test class).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: acbdcd9e-da10-404d-aede-f5cacad4beae
📒 Files selected for processing (15)
- cpp/tensorrt_llm/thop/attentionOp.cpp
- examples/visual_gen/README.md
- examples/visual_gen/visual_gen_flux.py
- examples/visual_gen/visual_gen_ltx2.py
- examples/visual_gen/visual_gen_wan_i2v.py
- examples/visual_gen/visual_gen_wan_t2v.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
- tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
- tensorrt_llm/_torch/visual_gen/config.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py
- tests/unittest/_torch/visual_gen/test_attention_integration.py
- tests/unittest/_torch/visual_gen/test_attention_perf.py
```python
class SageAttentionConfig(BaseModel):
    """Configuration for SageAttention quantization (TRTLLM backend only).

    SageAttention quantizes Q/K/V into FP8 (or INT8 for Q/K) with per-block
    scaling factors, enabling faster attention kernels. Providing this config
    to AttentionConfig enables SageAttention; omitting it (None) disables it.

    Similar to ``sparse_attention_config`` for the base TRTLLM attention
    backend — the presence of the config object signals enablement.

    Currently these (num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v)
    combinations are enabled:
    - (1, 1, 1)
    - (1, 4, 1)
    - (1, 16, 1) [for qk_int8 == True only]
    """

    num_elts_per_blk_q: int = PydanticField(
        1, ge=0, description="Elements per quantization block for Q (0 disables)"
    )
    num_elts_per_blk_k: int = PydanticField(
        4, ge=0, description="Elements per quantization block for K (0 disables)"
    )
    num_elts_per_blk_v: int = PydanticField(
        1, ge=0, description="Elements per quantization block for V (0 disables)"
    )
    qk_int8: bool = PydanticField(True, description="Use INT8 (vs E4M3) for Q/K quantization")


class AttentionConfig(StrictBaseModel):
    """Configuration for Attention layers."""

    backend: Literal["VANILLA", "TRTLLM", "FA4"] = PydanticField(
        "VANILLA", description="Attention backend: VANILLA (PyTorch SDPA), TRTLLM, FA4"
    )
    sage_attention_config: Optional[SageAttentionConfig] = PydanticField(
        None,
        description=(
            "SageAttention config (TRTLLM backend only). "
            "Set to a SageAttentionConfig instance to enable SageAttention; "
            "leave as None to disable."
        ),
    )
```
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Verify the current inheritance and StrictBaseModel definition in the repo.
rg -n -C 2 'class SageAttentionConfig|class StrictBaseModel' \
  tensorrt_llm/_torch/visual_gen/config.py \
  tensorrt_llm/llmapi/utils.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 654

🏁 Script executed:

```bash
head -60 tensorrt_llm/_torch/visual_gen/config.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2250

🏁 Script executed:

```bash
rg -n 'class \w+\((BaseModel|StrictBaseModel)' tensorrt_llm/_torch/visual_gen/config.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 499

🏁 Script executed:

```bash
sed -n '615,625p' tensorrt_llm/_torch/visual_gen/config.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 539
Make SageAttentionConfig inherit from StrictBaseModel.
This is a user-facing nested Pydantic model used within AttentionConfig, but it inherits from BaseModel instead of StrictBaseModel. This means unknown fields will be silently ignored rather than failing fast. Per the coding guidelines, user-facing Pydantic model classes should inherit from StrictBaseModel to fail fast on invalid field names.
Suggested fix

```diff
-class SageAttentionConfig(BaseModel):
+class SageAttentionConfig(StrictBaseModel):
```

StrictBaseModel is already imported in the file (line 13).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 54 - 96,
SageAttentionConfig currently inherits from BaseModel causing unknown fields to
be ignored; change its base to StrictBaseModel so nested user-facing Pydantic
validation fails fast. Locate the class definition for SageAttentionConfig and
replace BaseModel with StrictBaseModel (keeping existing field names like
num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v, qk_int8 and the
PydanticField usages) so AttentionConfig.sage_attention_config will enforce
strict validation.
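The fail-fast behavior at stake can be sketched without Pydantic: a minimal stand-in for strict vs lenient models. The class names and fields below are illustrative, not the real StrictBaseModel implementation (which uses Pydantic's `extra='forbid'`).

```python
# Minimal sketch of "strict" vs "lenient" config validation. A lenient model
# silently drops unknown field names, hiding typos; a strict one rejects them
# at construction time.
class LenientConfig:
    FIELDS = {"num_elts_per_blk_q": 1, "num_elts_per_blk_k": 4}

    def __init__(self, **kwargs):
        for name, default in self.FIELDS.items():
            setattr(self, name, kwargs.get(name, default))  # unknown keys ignored

class StrictConfig(LenientConfig):
    def __init__(self, **kwargs):
        unknown = set(kwargs) - set(self.FIELDS)
        if unknown:
            raise ValueError(f"Unknown config fields: {sorted(unknown)}")
        super().__init__(**kwargs)

# A typo passes silently with the lenient model...
lenient = LenientConfig(num_elts_per_blk_kk=16)  # note the typo "kk"
print(lenient.num_elts_per_blk_k)  # 4 -- the typo was ignored

# ...but fails fast with the strict one.
try:
    StrictConfig(num_elts_per_blk_kk=16)
except ValueError as e:
    print(e)
```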
```python
    num_elts_per_blk_q: int = PydanticField(
        1, ge=0, description="Elements per quantization block for Q (0 disables)"
    )
    num_elts_per_blk_k: int = PydanticField(
        4, ge=0, description="Elements per quantization block for K (0 disables)"
    )
    num_elts_per_blk_v: int = PydanticField(
        1, ge=0, description="Elements per quantization block for V (0 disables)"
    )
    qk_int8: bool = PydanticField(True, description="Use INT8 (vs E4M3) for Q/K quantization")
```
Align the "0 disables" wording with the validator.
The field descriptions suggest that zero is a supported "disable" value, but _validate_sage_attn_config() only accepts the whitelisted SageAttention presets. Either update the descriptions or add a real disabled path so the API matches the validation behavior.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 71 - 80, The
descriptions claim "0 disables" for num_elts_per_blk_q/k/v but
_validate_sage_attn_config only allows whitelisted SageAttention presets; update
the validator to implement the disabled path: inside _validate_sage_attn_config
detect if any of num_elts_per_blk_q, num_elts_per_blk_k, or num_elts_per_blk_v
== 0 and treat that as an explicit "disable quantization"/custom-config mode
(skip the preset whitelist check and ensure qk_int8 handling remains correct),
otherwise continue enforcing the existing preset whitelist; reference the fields
num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v and the validator
function _validate_sage_attn_config when making the change.
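One hypothetical shape such a validator could take, combining the preset whitelist with an explicit "0 disables" path. This is an illustrative sketch, not the actual `_validate_sage_attn_config`; the allowed qk_int8 values per preset are assumptions based on the docstring.

```python
# Hypothetical validator supporting both the preset whitelist and "0 disables".
ALLOWED_PRESETS = {
    (1, 1, 1): {False, True},   # (q, k, v) -> allowed qk_int8 values (assumed)
    (1, 4, 1): {False, True},
    (1, 16, 1): {True},         # per docstring: qk_int8 == True only
}

def validate_sage_attn_config(blk_q, blk_k, blk_v, qk_int8):
    """Return True if quantization is enabled, False if explicitly disabled."""
    if 0 in (blk_q, blk_k, blk_v):
        return False  # a zero block size means quantization is disabled
    preset = (blk_q, blk_k, blk_v)
    if preset not in ALLOWED_PRESETS or qk_int8 not in ALLOWED_PRESETS[preset]:
        raise ValueError(
            f"Unsupported SageAttention preset: {preset}, qk_int8={qk_int8}"
        )
    return True
```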
```python
        seq_len = q.shape[1]
        seq_len_kv = k.shape[1] if k is not None else seq_len
        out = self.attn.forward(
            q=q,
            k=k,
            v=v,
            batch_size=batch_size,
            seq_len=seq_len,
            seq_len_kv=seq_len_kv,
            **kwargs,
```
Fix the sequence-length derivation for HND backends.
q.shape[1] and k.shape[1] are the head dimensions after the HND reshape, so VANILLA/FA4 will pass head counts into forward(...) instead of the true sequence lengths. Compute the lengths before reshaping, or branch on backend_layout when setting seq_len and seq_len_kv.
🔧 Suggested fix (also keeps the k-is-None handling from the original code)

```diff
-        seq_len = q.shape[1]
-        seq_len_kv = k.shape[1] if k is not None else seq_len
+        is_hnd = backend_layout == AttentionTensorLayout.HND
+        seq_len = q.shape[2] if is_hnd else q.shape[1]
+        seq_len_kv = (
+            (k.shape[2] if is_hnd else k.shape[1]) if k is not None else seq_len
+        )
```
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py` around lines 311 - 320,
The seq_len and seq_len_kv values are taken from q.shape[1]/k.shape[1] after the
HND reshape, which yields head dimensions (so VANILLA/FA4 backends get head
counts instead of real sequence lengths); update the logic in the attention
module where seq_len and seq_len_kv are set (before calling self.attn.forward)
to compute lengths from the pre-reshaped tensors or branch on backend_layout
(e.g., if backend_layout indicates HND, derive seq_len/seq_len_kv from the
original un-reshaped sequence dimension or use stored sequence_length variables,
otherwise keep the current q.shape/k.shape logic) and pass those corrected
seq_len and seq_len_kv into self.attn.forward to ensure correct sequence length
semantics for HND backends.
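The fix reduces to picking the right dimension per layout, which a shape-only sketch makes clear. The helper name and layout strings below are hypothetical; only the dimension arithmetic is the point.

```python
# Sketch: deriving sequence lengths correctly for both tensor layouts.
#   SHD layout: (batch, seq, heads, dim)  -> sequence length is dim 1
#   HND layout: (batch, heads, seq, dim)  -> sequence length is dim 2
# The bug in the review: always reading dim 1, so HND returns head counts.
def derive_seq_lens(q_shape, k_shape, layout):
    seq_dim = 2 if layout == "HND" else 1
    seq_len = q_shape[seq_dim]
    seq_len_kv = k_shape[seq_dim] if k_shape is not None else seq_len
    return seq_len, seq_len_kv

# HND cross-attention: 12 heads, 4096 query tokens, 77 key/value tokens.
print(derive_seq_lens((2, 12, 4096, 128), (2, 12, 77, 128), "HND"))  # (4096, 77)
# SHD self-attention with k=None falls back to the query length.
print(derive_seq_lens((2, 4096, 12, 128), None, "SHD"))              # (4096, 4096)
```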
```python
# seq_len: pow2 baselines + real WAN latent token counts (VAE 8x spatial, 4x temporal, patch [1,2,2])
# batch_size: B=1 (cfg_size=2, split across GPUs) / B=2 (cfg_size=1, single GPU)
@pytest.mark.parametrize("seq_len", [256, 512, 1560, 3600, 4096, 16384, 32760])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("qk_int8", [False, True])
def test_sage_attention_self_attention(qk_int8: bool, batch_size: int, seq_len: int):
    """Test SageAttention (TRTLLM + sage_attention_config) self-attention.

    SageAttention quantizes Q/K/V with per-block scaling factors, so outputs
    are expected to differ from the naive SDPA reference. We verify:
    1. Forward pass completes without error
    2. Output shape matches naive
    3. Outputs are finite (no NaN/Inf)
    4. Approximate agreement with naive (cosine similarity > 0.99)
    """
    print("\n" + "=" * 60)
    print(f"Testing SageAttention (qk_int8={qk_int8}, B={batch_size}, S={seq_len})")
    print("=" * 60)

    # The sm100 sage kernel only has cubins for head_dim=128,
    # so match the WAN model dimensions (12 heads, head_dim=128).
    num_heads = 12
    head_dim = 128
    hidden_size = num_heads * head_dim  # 1536
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}, D={head_dim}")
    print(f"Device: {device}, dtype: {dtype}")

    sage_cfg = SageAttentionConfig(
        num_elts_per_blk_q=1,
        num_elts_per_blk_k=16 if qk_int8 else 1,
        num_elts_per_blk_v=1,
        qk_int8=qk_int8,
    )

    # Create models
    naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)

    model_config = create_model_config(
        hidden_size,
        num_heads,
        head_dim,
        attn_backend="TRTLLM",
        sage_attention_config=sage_cfg,
    )
    integrated = Attention(
        hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
    ).to(device)

    # Copy weights
    copy_weights_self_attention(naive, integrated)

    naive.eval()
    integrated.eval()

    # Create inputs
    torch.manual_seed(42)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
    freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
    freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)

    # Forward pass
    with torch.no_grad():
        out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
        out_sage = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))

    # --- Assertions ---

    # 1. Shape match
    assert out_sage.shape == out_naive.shape, (
        f"Shape mismatch: sage={out_sage.shape}, naive={out_naive.shape}"
    )

    # 2. All values finite (no NaN / Inf)
    assert torch.isfinite(out_sage).all(), (
        f"SageAttention output contains NaN or Inf (B={batch_size}, S={seq_len})"
    )

    # 3. Cosine similarity — sage quantization (FP8 per-block) introduces larger
    #    error than bf16 rounding, so elementwise allclose is too strict.
    #    Cosine similarity captures directional agreement robustly.
    max_diff = (out_naive - out_sage).abs().max().item()
    mean_diff = (out_naive - out_sage).abs().mean().item()
    cos_sim = F.cosine_similarity(
        out_naive.reshape(-1).float(), out_sage.reshape(-1).float(), dim=0
    ).item()

    print(f"\n  Output shape: {out_sage.shape}")
    print(f"  Max absolute diff: {max_diff:.2e}")
    print(f"  Mean absolute diff: {mean_diff:.2e}")
    print(f"  Cosine similarity: {cos_sim:.6f}")

    assert cos_sim > 0.99, (
        f"SageAttention cosine similarity too low: {cos_sim:.4f} < 0.99 "
        f"(B={batch_size}, S={seq_len}, qk_int8={qk_int8})"
    )
    return cos_sim > 0.99
```
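The cosine-similarity criterion used in the test above can be reproduced standalone, without torch; the sample vectors are of course hypothetical.

```python
import math

# Standalone version of the cosine-similarity check: directional agreement
# between reference and quantized outputs, which tolerates per-block
# quantization error better than elementwise allclose.
def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

reference = [1.0, 2.0, 3.0, 4.0]
quantized = [1.01, 1.98, 3.02, 3.99]  # small per-element quantization error

sim = cosine_similarity(reference, quantized)
assert sim > 0.99, f"cosine similarity too low: {sim:.4f}"
```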
Add coverage for the Wan2.1 k=4 preset.
This parametrization only exercises num_elts_per_blk_k values of 1 and 16, but the new Wan2.1 helper in visual_gen_wan_t2v.py selects 4. A regression in that path would still pass here. As per coding guidelines, this test should cover the model-specific preset that the feature actually wires up.
🔎 Verification / suggested test expansion

```bash
#!/bin/bash
# Check whether any other visual-gen tests already cover the Wan2.1 k=4 SageAttention preset.
rg -n -C 2 'num_elts_per_blk_k\s*=\s*4|SageAttentionConfig\(' \
  tests/unittest/_torch/visual_gen \
  examples/visual_gen \
  tensorrt_llm/_torch/visual_gen
```

```diff
-@pytest.mark.parametrize("qk_int8", [False, True])
-def test_sage_attention_self_attention(qk_int8: bool, batch_size: int, seq_len: int):
+@pytest.mark.parametrize(
+    "qk_int8,num_elts_per_blk_k",
+    [(False, 1), (True, 4), (True, 16)],
+)
+def test_sage_attention_self_attention(
+    qk_int8: bool, num_elts_per_blk_k: int, batch_size: int, seq_len: int
+):
 @@
-        num_elts_per_blk_k=16 if qk_int8 else 1,
+        num_elts_per_blk_k=num_elts_per_blk_k,
```
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/visual_gen/test_attention_integration.py` around lines
289 - 387, The test only exercises num_elts_per_blk_k values 1 and 16 via the
qk_int8 boolean in test_sage_attention_self_attention, so add coverage for the
Wan2.1 preset (k=4) by changing the parametrization: introduce a new parameter
(e.g., num_elts_per_blk_k_values) or replace qk_int8 with a param that
enumerates [1,4,16], then construct
SageAttentionConfig(num_elts_per_blk_k=that_value, num_elts_per_blk_q=1,
num_elts_per_blk_v=1, qk_int8=(that_value!=1)) inside the test; keep the test
function name test_sage_attention_self_attention and SageAttentionConfig usage
so the new k=4 path is exercised.
```python
# Quick / CI tests
# ------------------------------------------------------------------

@pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
```
🧩 Analysis chain

🏁 Script executed:

```bash
wc -l tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 124

🏁 Script executed:

```bash
sed -n '937,947p' tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 463

🏁 Script executed:

```bash
sed -n '963,973p' tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 617

🏁 Script executed:

```bash
sed -n '1000,1010p' tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 482

🏁 Script executed:

```bash
rg -n "_SAGE_CONFIGS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 85

🏁 Script executed:

```bash
rg -n "_SAGE_CONFIG_IDS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 141

🏁 Script executed:

```bash
sed -n '1,50p' tests/unittest/_torch/visual_gen/test_attention_perf.py | head -100
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 115

🏁 Script executed:

```bash
rg -B2 -A15 "^_SAGE_CONFIGS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 779

🏁 Script executed:

```bash
rg -B2 -A15 "^_SAGE_CONFIG_IDS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 720
Add strict=True to zip() in parametrized test matrices.
Without strict=True, any future list-length mismatch between _SAGE_CONFIGS and _SAGE_CONFIG_IDS silently truncates cases and reduces coverage.
Suggested diff
- @pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
+ @pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS, strict=True))

Also applies to: 968, 1005
🧰 Tools
🪛 Ruff (0.15.11)
[warning] 942-942: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` at line 942, The
parametrized tests use zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS) which will silently
drop cases if the two lists differ in length; update the pytest.mark.parametrize
calls to use zip(..., strict=True) for each occurrence (e.g. where
pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
at lines shown and the other occurrences referencing _SAGE_CONFIGS and
_SAGE_CONFIG_IDS) so that a length mismatch raises immediately; modify each zip
invocation (including the ones at the other noted locations) to pass
strict=True.
def test_sage_vs_vanilla_quick(self, sage_cfg: SageAttentionConfig, cfg_id: str):
    """Compare SageAttention timing against VANILLA at a quick size.

    Does not assert a minimum speedup — the goal is to catch regressions
    where sage unexpectedly becomes much slower than plain SDPA.
    """
    if sage_cfg.qk_int8 and torch.cuda.get_device_capability()[1] == 3:
        pytest.skip("SM103 does not have Int8 Tensor Cores.")

    batch, num_heads, seq_len, head_dim = 1, 12, 4096, 128

    vanilla = self.benchmark.benchmark_single(
        batch, num_heads, seq_len, head_dim, backend="VANILLA", verbose=False
    )
    trtllm = self.benchmark.benchmark_single(
        batch, num_heads, seq_len, head_dim, backend="TRTLLM", verbose=False
    )
    sage = self._bench(batch, num_heads, seq_len, head_dim, sage_cfg)

    assert vanilla is not None, "VANILLA benchmark failed"
    assert trtllm is not None, "TRTLLM benchmark failed"
    assert sage is not None, f"SageAttention {cfg_id} benchmark failed"

    speedup_vs_vanilla = vanilla["avg_ms"] / sage["avg_ms"]
    speedup_vs_trtllm = trtllm["avg_ms"] / sage["avg_ms"]
    print(
        f"\n  sage {cfg_id} at S={seq_len}: "
        f"avg={sage['avg_ms']:.3f}ms "
        f"vs_vanilla={speedup_vs_vanilla:.2f}x "
        f"vs_trtllm={speedup_vs_trtllm:.2f}x"
    )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether perf CI list coverage exists for this feature area.
fd -i 'l0_perf.yml' tests/integration/test_lists/test-db
fd -i 'llm_perf_*.yml' tests/integration/test_lists/qa
fd -i 'test_perf_sanity.py' tests/integration/defs/perf
# Look for WAN/visual-gen/sage attention coverage in perf defs/lists.
rg -n --iglob '*l0*.yml' --iglob 'llm_perf_*.yml' --iglob 'test_perf_sanity.py' \
  'sage|sageattention|wan|visual_gen|attention' tests/integration
Length of output: 50376

🏁 Script executed:
# Search specifically for test_attention_perf.py in test lists and perf files
rg -i 'test_attention_perf' tests/integration/test_lists/ tests/integration/defs/perf/ --max-count 20
Length of output: 166

🏁 Script executed:
# Search for sage or sageattention specifically in perf test files
rg -i 'sage|sageattention' tests/integration/test_lists/test-db/ tests/integration/defs/perf/ --max-count 20
Length of output: 6085

🏁 Script executed:
# Check the actual test method to see if there are assertions or guards
sed -n '969,1040p' tests/unittest/_torch/visual_gen/test_attention_perf.py
Length of output: 2989
Add regression guards or a perf baseline assertion to catch SageAttention slowdowns.
The test_sage_vs_vanilla_quick and test_sage_vs_trtllm_wan_shapes tests (also at lines 1010–1038) print speedup metrics but do not assert minimum thresholds. The docstring explicitly states "Does not assert a minimum speedup." If SageAttention becomes materially slower, these tests pass anyway, allowing regressions to go undetected.
While the test file is listed in tests/integration/test_lists/test-db/l0_b200.yml, the lack of threshold enforcement means performance regressions slip through CI. Add either:
- A loose minimum speedup assertion (e.g., assert speedup_vs_vanilla > 0.9, f"slowdown: {speedup_vs_vanilla:.2f}x"), or
- A perf-sanity baseline in tests/integration/defs/perf/test_perf_sanity.py that records and trend-checks latencies.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` around lines 969 -
999, Add a loose performance assertion to prevent silent regressions: in
test_sage_vs_vanilla_quick (and similarly in test_sage_vs_trtllm_wan_shapes)
after computing speedup_vs_vanilla and speedup_vs_trtllm, add an assertion such
as assert speedup_vs_vanilla > 0.9, f"slowdown vs VANILLA:
{speedup_vs_vanilla:.2f}x" (and/or assert speedup_vs_trtllm > 0.9 with a similar
message) so the test fails when SageAttention is materially slower; update the
assertion threshold as needed for CI stability and keep the existing print
statements intact.
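A minimal sketch of such a guard, kept separate from the repo's benchmark harness; the function name, timing values, and the 0.9 floor below are illustrative choices, not values from the codebase:

```python
def check_speedup(vanilla_ms: float, sage_ms: float, floor: float = 0.9) -> float:
    """Fail when SageAttention is materially slower than the baseline.

    `floor` is deliberately loose: 0.9 tolerates ~10% timing jitter in CI
    while still catching large regressions.
    """
    speedup = vanilla_ms / sage_ms
    assert speedup > floor, f"slowdown vs VANILLA: {speedup:.2f}x"
    return speedup


# Example: 1.2 ms baseline vs 0.8 ms sage gives ~1.5x and passes the guard.
speedup = check_speedup(vanilla_ms=1.2, sage_ms=0.8)
```

Dropping this after the existing `speedup_vs_vanilla` computation keeps the printed metrics intact while turning a material slowdown into a test failure rather than a log line nobody reads.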
|
PR_Github #45422 [ run ] completed with state
|
Summary by CodeRabbit
New Features
--enable_sage_attention CLI option
Description
Integrate SageAttention kernels from #12937 into VisualGen for Wan and FLUX.2 models.
LTX-2 SageAttention is not yet supported.
Preliminary perf / quality compared to #12548 baseline from example scripts:
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.