[TRTLLM-12128][feat] enable SageAttention for Wan/FLUX#13425

Open
o-stoner wants to merge 5 commits into NVIDIA:main from o-stoner:user/o-stoner/visual-gen-enable-sageattn

Conversation

@o-stoner
Collaborator

@o-stoner o-stoner commented Apr 24, 2026

Summary by CodeRabbit

  • New Features

    • Added SageAttention support to visual generation models (FLUX, WAN T2V, and WAN I2V) with a new --enable_sage_attention CLI option
    • Enabled per-block INT8 quantization optimizations for improved memory efficiency and attention performance
  • Documentation

    • Updated configuration guidance for the new SageAttention option

Description

Integrate SageAttention kernels from #12937 into VisualGen for Wan and FLUX.2 models.
LTX-2 SageAttention is not yet supported.

Preliminary performance and quality results compared to the #12548 baseline, measured with the example scripts:

| Model | Configs | Pipeline Time (Baseline) | Output (Baseline) | Pipeline Time (SageAttention) | Output (SageAttention) | Speedup |
|---|---|---|---|---|---|---|
| Wan2.1-T2V-1.3B | "A cute cat playing piano", 480x832, 33 frames, 50 steps, trtllm-fp8-blockwise linear | 7.94s | video | 7.54s | video | 1.05x |
| Wan2.1-I2V-14B-480P | 480x832, 81 frames, 50 steps, trtllm-nvfp4 linear, parallel VAE disabled | 128.79s | video | 117.38s | video | 1.10x |
| Wan2.1-I2V-14B-720P | 720x1280, 81 frames, 50 steps, trtllm-nvfp4 linear, parallel VAE disabled | 501.23s | video | 458.33s | video | 1.09x |
| Wan2.1-T2V-14B-Diffusers | "A cute cat playing piano", 720x1280, 81 frames, 50 steps, trtllm-nvfp4 linear | 484.63s | video | 413.65s | video | 1.17x |
| Wan2.2-I2V-A14B | 720x1280, 81 frames, 40 steps, trtllm-nvfp4 linear, parallel VAE disabled | 387.83s | video | 340.48s | video | 1.14x |
| Wan2.2-T2V-A14B | "A cute cat playing piano", 720x1280, 81 frames, 40 steps, trtllm-nvfp4 linear, parallel VAE disabled | 389.10s | video | 331.24s | video | 1.17x |
| FLUX.2-dev | "A cat sitting on a windowsill", 1024x1024, 50 steps, trtllm-nvfp4 linear | 5.60s | PNG | 5.61s | PNG | 1x |
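
For reference, here is an illustrative way to enable the new option on the Wan T2V example; only --enable_sage_attention is introduced by this PR, and the other arguments shown are placeholders for the script's existing CLI (see examples/visual_gen/README.md for the actual arguments), not verified flag names.

```bash
# Illustrative sketch only: flags other than --enable_sage_attention are assumed,
# not taken from the real visual_gen_wan_t2v.py argument parser.
python examples/visual_gen/visual_gen_wan_t2v.py \
    --prompt "A cute cat playing piano" \
    --enable_sage_attention
```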

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com>
@o-stoner
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45422 [ run ] triggered by Bot. Commit: 2c0a848 Link to invocation

@o-stoner o-stoner changed the title from "[TRTLLM-12128][feat] enable SageAttention" to "[TRTLLM-12128][feat] enable SageAttention for WAN/FLUX" Apr 24, 2026
@o-stoner o-stoner changed the title from "[TRTLLM-12128][feat] enable SageAttention for WAN/FLUX" to "[TRTLLM-12128][feat] enable SageAttention for Wan/FLUX" Apr 24, 2026
@o-stoner o-stoner marked this pull request as ready for review April 24, 2026 23:28
@o-stoner o-stoner requested review from a team as code owners April 24, 2026 23:28
@coderabbitai
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

This pull request adds SageAttention support to visual generation models by refactoring the attention configuration from individual parameters into a unified SageAttentionConfig class, updating the TRTLLM attention backend API to accept explicit sequence metadata, and adding configuration validation and runtime checks across multiple visual generation examples and pipelines.

Changes

  • Core Tensor Operations (cpp/tensorrt_llm/thop/attentionOp.cpp): Adjusted context-path validation and pointer setup; Runner::run now conditionally derives k_ptr/v_ptr from optional tensors with slicing, and updated error messages to clarify allowed attention configurations when SageAttention is enabled.
  • VisualGen CLI Examples (examples/visual_gen/visual_gen_flux.py, examples/visual_gen/visual_gen_wan_i2v.py, examples/visual_gen/visual_gen_wan_t2v.py): Added --enable_sage_attention CLI option and dynamic attention configuration building; when enabled, each script constructs SageAttention-specific per-block element layouts and INT8 quantization settings based on model checkpoint patterns, with appropriate logging.
  • VisualGen Configuration (examples/visual_gen/visual_gen_ltx2.py): Refactored attention configuration construction into a local dictionary instead of inline backend specification; no new parameters added.
  • Documentation (examples/visual_gen/README.md): Extended documentation to include WAN T2V SageAttention run configuration example and added --enable_sage_attention to the common arguments reference table.
  • VisualGen Config System (tensorrt_llm/_torch/visual_gen/config.py): Introduced SageAttentionConfig Pydantic model and added sage_attention_config field to AttentionConfig with validation ensuring TRTLLM backend requirement and supported block-size/quantization combinations (see the config sketch after this list).
  • Attention Backend (tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py): Replaced four separate SageAttention keyword arguments with single sage_attention_config parameter; updated forward signature to require explicit batch_size and seq_len and optional seq_len_kv; refactored runtime branching to reshape tensors and apply SageAttention-specific quantization when config is present.
  • Backend Factory & Dispatch (tensorrt_llm/_torch/visual_gen/attention_backend/utils.py, tensorrt_llm/_torch/attention_backend/trtllm.py): Updated factory to forward sage_attention_config to backend constructors; dispatch layer now prevents trtllm_gen_attention kernel usage when SageAttention is enabled, falling back to thop.attention.
  • VisualGen Modules & Pipelines (tensorrt_llm/_torch/visual_gen/modules/attention.py, tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py): Attention module now passes explicit batch_size, seq_len, and derived seq_len_kv to backend forward; LTX2 pipeline added check to reject SageAttention configuration with NotImplementedError.
  • Tests (tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py, tests/unittest/_torch/visual_gen/test_attention_integration.py, tests/unittest/_torch/visual_gen/test_attention_perf.py): Updated test harnesses to use SageAttentionConfig object instead of separate parameters; added new parametrized SageAttention integration tests validating shape, finiteness, and cosine similarity (>0.99); introduced SM100-gated performance benchmarks comparing SageAttention against VANILLA and TRTLLM baselines.
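
To make the new configuration surface concrete, here is a minimal usage sketch; the class and field names follow the SageAttentionConfig / AttentionConfig definitions quoted later in this review, while the import path and the specific preset values are illustrative assumptions rather than prescribed usage.

```python
# Sketch only: module path assumed from the changed file
# tensorrt_llm/_torch/visual_gen/config.py; presets follow the class docstring.
from tensorrt_llm._torch.visual_gen.config import AttentionConfig, SageAttentionConfig

# The presence of sage_attention_config enables SageAttention (None disables it);
# the (q, k, v) per-block sizes must be one of the supported combinations,
# e.g. (1, 1, 1), (1, 4, 1), or (1, 16, 1) with qk_int8=True.
attn_cfg = AttentionConfig(
    backend="TRTLLM",  # SageAttention is only validated for the TRTLLM backend
    sage_attention_config=SageAttentionConfig(
        num_elts_per_blk_q=1,
        num_elts_per_blk_k=4,
        num_elts_per_blk_v=1,
        qk_int8=True,
    ),
)
```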

Sequence Diagram

sequenceDiagram
    participant CLI as Example Script
    participant Config as Config System
    participant Factory as Attention Factory
    participant Backend as TrtllmAttention
    participant Ops as Low-level Ops

    CLI->>CLI: Parse --enable_sage_attention
    CLI->>Config: Build AttentionConfig with sage_attention_config
    Config->>Config: Validate backend=TRTLLM & supported block sizes
    Config-->>Factory: Return AttentionConfig
    Factory->>Factory: Extract sage_attention_config
    Factory->>Backend: Create with sage_attention_config
    Backend->>Backend: Store sage_attention_config
    
    alt SageAttention Enabled
        Backend->>Backend: Reshape Q/K/V to flattened forms
        Backend->>Ops: Call forward with per-block quantization
        Ops-->>Backend: Return attention output
    else Standard Path
        Backend->>Backend: Fuse/concatenate Q/K/V
        Backend->>Ops: Call forward without SageAttention params
        Ops-->>Backend: Return attention output
    end
    
    Backend-->>CLI: Attention result

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 68.89%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title '[TRTLLM-12128][feat] enable SageAttention for Wan/FLUX' directly and clearly summarizes the main change: enabling SageAttention support for Wan and FLUX models as a feature addition. |
| Description check | ✅ Passed | The PR description includes a clear explanation of the feature (SageAttention integration from PR #12937), mentions that LTX-2 is not yet supported, provides a comprehensive performance/quality comparison table with multiple models and speedups, and has the PR checklist mostly completed with the final checkbox marked. However, the 'Test Coverage' section is empty. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py (1)

21-23: ⚠️ Potential issue | 🟡 Minor

Fix the example pytest target.

The documented command points at test_sage_ulysses_attention.py, but this file is named test_ulysses_sage_attention.py, so copy/paste will fail.

🛠 Suggested fix
-    pytest tests/unittest/_torch/visual_gen/multi_gpu/test_sage_ulysses_attention.py -v
+    pytest tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py -v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py`
around lines 21 - 23, Update the example pytest command in the docstring so it
references the correct test filename: replace the incorrect
`test_sage_ulysses_attention.py` target with `test_ulysses_sage_attention.py`
(the actual test file name) in the run example shown in the file
`tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py`;
ensure the example reads: pytest
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py -v so
copy/paste runs the right test.
cpp/tensorrt_llm/thop/attentionOp.cpp (1)

652-675: ⚠️ Potential issue | 🔴 Critical

Remove the overly broad global gate.

The check is_mla_enable || is_fused_qkv || use_sage_attn conflates three independent concerns and blocks future callers from using unfused K/V without MLA or SageAttention enabled. The SageAttention-specific validation at lines 669–675 remains and ensures correctness for SageAttention paths.

🔧 Suggested fix
-    TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || use_sage_attn,
-        "Context attention only allows these non-MLA cases: fused QKV; separate QKV with SageAttention");
     TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/thop/attentionOp.cpp` around lines 652 - 675, Remove the
broad gate that currently enforces "is_mla_enable || is_fused_qkv ||
use_sage_attn" by deleting the TLLM_CHECK_WITH_INFO call that uses those three
flags; instead rely on the existing, more specific validations already present
(the later TLLM_CHECK_WITH_INFO calls that reference is_fused_qkv,
k.has_value(), v.has_value(), update_kv_cache, and the SageAttention checks
guarded by use_sage_attn). Ensure no other logic depends on that global check
and keep the local checks for fused QKV (is_fused_qkv), KV-cache update
(update_kv_cache), and SageAttention (use_sage_attn) intact so unfused K/V paths
are allowed when appropriate.
tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (1)

268-295: ⚠️ Potential issue | 🟠 Major

Cross-attention with seq_len != seq_len_kv can fail in non-Sage path.

When k/v are provided and lengths differ, _concat_qkv() flattens to [B*seq_len, ...] and [B*kv_seq_len, ...]; torch.cat(..., dim=-1) then errors due to mismatched first dimension.

Suggested fix (handle unequal Q/KV lengths explicitly)
         else:
             if k is None and v is None:
                 qkv = q.reshape(batch_size * seq_len, -1)
-            else:
+                output = super().forward(
+                    q=qkv,
+                    k=None,
+                    v=None,
+                    metadata=prepared_metadata,
+                    attention_mask=attention_mask,
+                )
+            elif kv_seq_len != seq_len:
+                q_flat = q.reshape(batch_size * seq_len, -1)
+                k_flat = k.reshape(batch_size * kv_seq_len, -1)
+                v_flat = v.reshape(batch_size * kv_seq_len, -1)
+                output = super().forward(
+                    q=q_flat,
+                    k=k_flat,
+                    v=v_flat,
+                    metadata=prepared_metadata,
+                    attention_mask=attention_mask,
+                )
+            else:
                 qkv = self._concat_qkv(q, k, v, batch_size, seq_len, kv_seq_len)
-            output = super().forward(
-                q=qkv,
-                k=None,
-                v=None,
-                metadata=prepared_metadata,
-                attention_mask=attention_mask,
-            )
+                output = super().forward(
+                    q=qkv,
+                    k=None,
+                    v=None,
+                    metadata=prepared_metadata,
+                    attention_mask=attention_mask,
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` around lines 268
- 295, The non-Sage branch fails when seq_len_kv != seq_len because _concat_qkv
flattens Q and K/V to different first-dim sizes (batch_size*seq_len vs
batch_size*kv_seq_len) and torch.cat errors; fix by detecting when k and v are
provided and kv_seq_len != seq_len and, instead of calling _concat_qkv, reshape
q to (batch_size*seq_len, -1) and k, v to (batch_size*kv_seq_len, -1)
individually and pass q, k, v separately into super().forward (same pattern used
in the Sage branch) so the backend receives correctly-shaped Q/K/V; update the
forward implementation (and if needed _concat_qkv) to only use _concat_qkv when
seq_len == seq_len_kv.
🧹 Nitpick comments (2)
tests/unittest/_torch/visual_gen/test_attention_perf.py (1)

897-903: Prefer immutable class-level test matrices.

WAN_SEQ_LENS and QUICK_SEQ_LENS are mutable class attributes; accidental mutation can leak across tests.

Suggested diff
-    WAN_SEQ_LENS = [
+    WAN_SEQ_LENS = (
         (1, 12, 14040, 128, "wan_1.3b_480p_33f"),
         (1, 12, 32760, 128, "wan_1.3b_480p_81f"),
         (1, 40, 14040, 128, "wan_14b_480p_33f"),
         (1, 40, 32760, 128, "wan_14b_480p_81f"),
         (1, 40, 75600, 128, "wan_14b_720p_81f"),
-    ]
+    )
...
-    QUICK_SEQ_LENS = [
+    QUICK_SEQ_LENS = (
         (1, 12, 1024, 128, "quick_1k"),
         (1, 12, 4096, 128, "quick_4k"),
         (2, 12, 1024, 128, "quick_batch2"),
-    ]
+    )

Also applies to: 906-910

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` around lines 897 -
903, WAN_SEQ_LENS and QUICK_SEQ_LENS are defined as mutable lists at class scope
which can be accidentally mutated across tests; change both to immutable
sequence types (e.g., tuples or nested tuples) and keep them as class-level
constants (retain names) so test matrices cannot be modified at runtime,
updating any code that iterates them to work with tuples if necessary (refer to
WAN_SEQ_LENS and QUICK_SEQ_LENS in the test class).
tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (1)

271-274: Use an explicit exception instead of assert for runtime validation.

assert can be stripped with optimized Python execution; this check should be a guaranteed runtime guard.

Suggested diff
-            assert k is not None and v is not None, (
-                "SageAttention requires separate Q, K, V tensors"
-            )
+            if k is None or v is None:
+                raise ValueError("SageAttention requires separate Q, K, V tensors")

As per coding guidelines: “Use built-in Python exception types; use exceptions for error handling, not return values.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` around lines 271
- 274, Replace the runtime assert in the block that checks
self.sage_attention_config with an explicit exception raise so the validation
always runs; specifically, in the method containing the lines checking
self.sage_attention_config, replace "assert k is not None and v is not None"
with a raised built-in exception (e.g., raise ValueError) that includes a clear
message like "SageAttention requires separate Q, K, V tensors" and reference k
and v being None as needed to aid debugging.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: acbdcd9e-da10-404d-aede-f5cacad4beae

📥 Commits

Reviewing files that changed from the base of the PR and between c4b8e8e and 2c0a848.

📒 Files selected for processing (15)
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • examples/visual_gen/README.md
  • examples/visual_gen/visual_gen_flux.py
  • examples/visual_gen/visual_gen_ltx2.py
  • examples/visual_gen/visual_gen_wan_i2v.py
  • examples/visual_gen/visual_gen_wan_t2v.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
  • tensorrt_llm/_torch/visual_gen/config.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_sage_attention.py
  • tests/unittest/_torch/visual_gen/test_attention_integration.py
  • tests/unittest/_torch/visual_gen/test_attention_perf.py

Comment on lines +54 to +96
class SageAttentionConfig(BaseModel):
    """Configuration for SageAttention quantization (TRTLLM backend only).

    SageAttention quantizes Q/K/V into FP8 (or INT8 for Q/K) with per-block
    scaling factors, enabling faster attention kernels. Providing this config
    to AttentionConfig enables SageAttention; omitting it (None) disables it.

    Similar to ``sparse_attention_config`` for the base TRTLLM attention
    backend — the presence of the config object signals enablement.

    Currently these (num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v)
    combinations are enabled:
    - (1, 1, 1)
    - (1, 4, 1)
    - (1, 16, 1) [for qk_int8 == True only]
    """

    num_elts_per_blk_q: int = PydanticField(
        1, ge=0, description="Elements per quantization block for Q (0 disables)"
    )
    num_elts_per_blk_k: int = PydanticField(
        4, ge=0, description="Elements per quantization block for K (0 disables)"
    )
    num_elts_per_blk_v: int = PydanticField(
        1, ge=0, description="Elements per quantization block for V (0 disables)"
    )
    qk_int8: bool = PydanticField(True, description="Use INT8 (vs E4M3) for Q/K quantization")


class AttentionConfig(StrictBaseModel):
    """Configuration for Attention layers."""

    backend: Literal["VANILLA", "TRTLLM", "FA4"] = PydanticField(
        "VANILLA", description="Attention backend: VANILLA (PyTorch SDPA), TRTLLM, FA4"
    )
    sage_attention_config: Optional[SageAttentionConfig] = PydanticField(
        None,
        description=(
            "SageAttention config (TRTLLM backend only). "
            "Set to a SageAttentionConfig instance to enable SageAttention; "
            "leave as None to disable."
        ),
    )
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify the current inheritance and StrictBaseModel definition in the repo.
rg -n -C 2 'class SageAttentionConfig|class StrictBaseModel' \
  tensorrt_llm/_torch/visual_gen/config.py \
  tensorrt_llm/llmapi/utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 654


🏁 Script executed:

head -60 tensorrt_llm/_torch/visual_gen/config.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2250


🏁 Script executed:

rg -n 'class \w+\((BaseModel|StrictBaseModel)' tensorrt_llm/_torch/visual_gen/config.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 499


🏁 Script executed:

sed -n '615,625p' tensorrt_llm/_torch/visual_gen/config.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 539


Make SageAttentionConfig inherit from StrictBaseModel.

This is a user-facing nested Pydantic model used within AttentionConfig, but it inherits from BaseModel instead of StrictBaseModel. This means unknown fields will be silently ignored rather than failing fast. Per the coding guidelines, user-facing Pydantic model classes should inherit from StrictBaseModel to fail fast on invalid field names.

Suggested fix
-class SageAttentionConfig(BaseModel):
+class SageAttentionConfig(StrictBaseModel):

StrictBaseModel is already imported in the file (line 13).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 54 - 96,
SageAttentionConfig currently inherits from BaseModel causing unknown fields to
be ignored; change its base to StrictBaseModel so nested user-facing Pydantic
validation fails fast. Locate the class definition for SageAttentionConfig and
replace BaseModel with StrictBaseModel (keeping existing field names like
num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v, qk_int8 and the
PydanticField usages) so AttentionConfig.sage_attention_config will enforce
strict validation.

Comment on lines +71 to +80
    num_elts_per_blk_q: int = PydanticField(
        1, ge=0, description="Elements per quantization block for Q (0 disables)"
    )
    num_elts_per_blk_k: int = PydanticField(
        4, ge=0, description="Elements per quantization block for K (0 disables)"
    )
    num_elts_per_blk_v: int = PydanticField(
        1, ge=0, description="Elements per quantization block for V (0 disables)"
    )
    qk_int8: bool = PydanticField(True, description="Use INT8 (vs E4M3) for Q/K quantization")
Contributor


⚠️ Potential issue | 🟡 Minor

Align the "0 disables" wording with the validator.

The field descriptions suggest that zero is a supported “disable” value, but _validate_sage_attn_config() only accepts the whitelisted SageAttention presets. Either update the description or add a real disabled path so the API matches the validation behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 71 - 80, The
descriptions claim "0 disables" for num_elts_per_blk_q/k/v but
_validate_sage_attn_config only allows whitelisted SageAttention presets; update
the validator to implement the disabled path: inside _validate_sage_attn_config
detect if any of num_elts_per_blk_q, num_elts_per_blk_k, or num_elts_per_blk_v
== 0 and treat that as an explicit "disable quantization"/custom-config mode
(skip the preset whitelist check and ensure qk_int8 handling remains correct),
otherwise continue enforcing the existing preset whitelist; reference the fields
num_elts_per_blk_q, num_elts_per_blk_k, num_elts_per_blk_v and the validator
function _validate_sage_attn_config when making the change.
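
For concreteness, a minimal sketch of the "disabled path" this comment asks for; per the walkthrough the real validator (_validate_sage_attn_config) lives on AttentionConfig and also enforces the TRTLLM-backend requirement, so placing it directly on SageAttentionConfig here and the exact error handling are simplifying assumptions, not the PR's implementation.

```python
from pydantic import BaseModel, model_validator

# Preset whitelist taken from the SageAttentionConfig docstring quoted in this review.
_ALLOWED_PRESETS = {(1, 1, 1), (1, 4, 1), (1, 16, 1)}


class SageAttentionConfig(BaseModel):
    num_elts_per_blk_q: int = 1
    num_elts_per_blk_k: int = 4
    num_elts_per_blk_v: int = 1
    qk_int8: bool = True

    @model_validator(mode="after")
    def _validate_sage_attn_config(self) -> "SageAttentionConfig":
        blocks = (self.num_elts_per_blk_q, self.num_elts_per_blk_k, self.num_elts_per_blk_v)
        # Hypothetical "0 disables" path: any zero block size bypasses the preset whitelist.
        if 0 in blocks:
            return self
        if blocks not in _ALLOWED_PRESETS:
            raise ValueError(f"Unsupported SageAttention block sizes {blocks}")
        if blocks == (1, 16, 1) and not self.qk_int8:
            raise ValueError("(1, 16, 1) requires qk_int8=True")
        return self
```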

Comment on lines +311 to +320
        seq_len = q.shape[1]
        seq_len_kv = k.shape[1] if k is not None else seq_len
        out = self.attn.forward(
            q=q,
            k=k,
            v=v,
            batch_size=batch_size,
            seq_len=seq_len,
            seq_len_kv=seq_len_kv,
            **kwargs,
Contributor


⚠️ Potential issue | 🔴 Critical

Fix the sequence-length derivation for HND backends.

q.shape[1] and k.shape[1] are the head dimensions after the HND reshape, so VANILLA/FA4 will pass head counts into forward(...) instead of the true sequence lengths. Compute the lengths before reshaping, or branch on backend_layout when setting seq_len and seq_len_kv.

🔧 Suggested fix
-        seq_len = q.shape[1]
-        seq_len_kv = k.shape[1] if k is not None else seq_len
+        seq_len = q.shape[2] if backend_layout == AttentionTensorLayout.HND else q.shape[1]
+        seq_len_kv = (
+            k.shape[2] if backend_layout == AttentionTensorLayout.HND else seq_len
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/modules/attention.py` around lines 311 - 320,
The seq_len and seq_len_kv values are taken from q.shape[1]/k.shape[1] after the
HND reshape, which yields head dimensions (so VANILLA/FA4 backends get head
counts instead of real sequence lengths); update the logic in the attention
module where seq_len and seq_len_kv are set (before calling self.attn.forward)
to compute lengths from the pre-reshaped tensors or branch on backend_layout
(e.g., if backend_layout indicates HND, derive seq_len/seq_len_kv from the
original un-reshaped sequence dimension or use stored sequence_length variables,
otherwise keep the current q.shape/k.shape logic) and pass those corrected
seq_len and seq_len_kv into self.attn.forward to ensure correct sequence length
semantics for HND backends.

Comment on lines +289 to +387
# seq_len: pow2 baselines + real WAN latent token counts (VAE 8x spatial, 4x temporal, patch [1,2,2])
# batch_size: B=1 (cfg_size=2, split across GPUs) / B=2 (cfg_size=1, single GPU)
@pytest.mark.parametrize("seq_len", [256, 512, 1560, 3600, 4096, 16384, 32760])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("qk_int8", [False, True])
def test_sage_attention_self_attention(qk_int8: bool, batch_size: int, seq_len: int):
    """Test SageAttention (TRTLLM + sage_attention_config) self-attention.

    SageAttention quantizes Q/K/V with per-block scaling factors, so outputs
    are expected to differ from the naive SDPA reference. We verify:
    1. Forward pass completes without error
    2. Output shape matches naive
    3. Outputs are finite (no NaN/Inf)
    4. Approximate agreement with naive (cosine similarity > 0.99)
    """
    print("\n" + "=" * 60)
    print(f"Testing SageAttention (qk_int8={qk_int8}, B={batch_size}, S={seq_len})")
    print("=" * 60)

    # The sm100 sage kernel only has cubins for head_dim=128,
    # so match the WAN model dimensions (12 heads, head_dim=128).
    num_heads = 12
    head_dim = 128
    hidden_size = num_heads * head_dim  # 1536
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    print(f"Config: B={batch_size}, S={seq_len}, H={hidden_size}, heads={num_heads}, D={head_dim}")
    print(f"Device: {device}, dtype: {dtype}")

    sage_cfg = SageAttentionConfig(
        num_elts_per_blk_q=1,
        num_elts_per_blk_k=16 if qk_int8 else 1,
        num_elts_per_blk_v=1,
        qk_int8=qk_int8,
    )

    # Create models
    naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device)

    model_config = create_model_config(
        hidden_size,
        num_heads,
        head_dim,
        attn_backend="TRTLLM",
        sage_attention_config=sage_cfg,
    )
    integrated = Attention(
        hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=model_config
    ).to(device)

    # Copy weights
    copy_weights_self_attention(naive, integrated)

    naive.eval()
    integrated.eval()

    # Create inputs
    torch.manual_seed(42)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
    freqs_cos_HSD, freqs_sin_HSD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True)
    freqs_cos_SHD, freqs_sin_SHD = generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False)

    # Forward pass
    with torch.no_grad():
        out_naive = naive(hidden_states, freqs_cos_HSD, freqs_sin_HSD)
        out_sage = integrated(hidden_states, freqs=(freqs_cos_SHD, freqs_sin_SHD))

    # --- Assertions ---

    # 1. Shape match
    assert out_sage.shape == out_naive.shape, (
        f"Shape mismatch: sage={out_sage.shape}, naive={out_naive.shape}"
    )

    # 2. All values finite (no NaN / Inf)
    assert torch.isfinite(out_sage).all(), (
        f"SageAttention output contains NaN or Inf (B={batch_size}, S={seq_len})"
    )

    # 3. Cosine similarity — sage quantization (FP8 per-block) introduces larger
    #    error than bf16 rounding, so elementwise allclose is too strict.
    #    Cosine similarity captures directional agreement robustly.
    max_diff = (out_naive - out_sage).abs().max().item()
    mean_diff = (out_naive - out_sage).abs().mean().item()
    cos_sim = F.cosine_similarity(
        out_naive.reshape(-1).float(), out_sage.reshape(-1).float(), dim=0
    ).item()

    print(f"\n Output shape: {out_sage.shape}")
    print(f" Max absolute diff: {max_diff:.2e}")
    print(f" Mean absolute diff: {mean_diff:.2e}")
    print(f" Cosine similarity: {cos_sim:.6f}")

    assert cos_sim > 0.99, (
        f"SageAttention cosine similarity too low: {cos_sim:.4f} < 0.99 "
        f"(B={batch_size}, S={seq_len}, qk_int8={qk_int8})"
    )
    return cos_sim > 0.99
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add coverage for the Wan2.1 k=4 preset.

This parametrization only exercises num_elts_per_blk_k values of 1 and 16, but the new Wan2.1 helper in visual_gen_wan_t2v.py selects 4. A regression in that path would still pass here. As per coding guidelines, this test should cover the model-specific preset that the feature actually wires up.

🔎 Verification / suggested test expansion
#!/bin/bash
# Check whether any other visual-gen tests already cover the Wan2.1 k=4 SageAttention preset.
rg -n -C 2 'num_elts_per_blk_k\s*=\s*4|SageAttentionConfig\(' \
  tests/unittest/_torch/visual_gen \
  examples/visual_gen \
  tensorrt_llm/_torch/visual_gen
-@pytest.mark.parametrize("qk_int8", [False, True])
-def test_sage_attention_self_attention(qk_int8: bool, batch_size: int, seq_len: int):
+@pytest.mark.parametrize(
+    "qk_int8,num_elts_per_blk_k",
+    [(False, 1), (True, 4), (True, 16)],
+)
+def test_sage_attention_self_attention(
+    qk_int8: bool, num_elts_per_blk_k: int, batch_size: int, seq_len: int
+):
@@
-        num_elts_per_blk_k=16 if qk_int8 else 1,
+        num_elts_per_blk_k=num_elts_per_blk_k,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/test_attention_integration.py` around lines
289 - 387, The test only exercises num_elts_per_blk_k values 1 and 16 via the
qk_int8 boolean in test_sage_attention_self_attention, so add coverage for the
Wan2.1 preset (k=4) by changing the parametrization: introduce a new parameter
(e.g., num_elts_per_blk_k_values) or replace qk_int8 with a param that
enumerates [1,4,16], then construct
SageAttentionConfig(num_elts_per_blk_k=that_value, num_elts_per_blk_q=1,
num_elts_per_blk_v=1, qk_int8=(that_value!=1)) inside the test; keep the test
function name test_sage_attention_self_attention and SageAttentionConfig usage
so the new k=4 path is exercised.

# Quick / CI tests
# ------------------------------------------------------------------

@pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

wc -l tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 124


🏁 Script executed:

sed -n '937,947p' tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 463


🏁 Script executed:

sed -n '963,973p' tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 617


🏁 Script executed:

sed -n '1000,1010p' tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 482


🏁 Script executed:

rg -n "_SAGE_CONFIGS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 85


🏁 Script executed:

rg -n "_SAGE_CONFIG_IDS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 141


🏁 Script executed:

sed -n '1/50p' tests/unittest/_torch/visual_gen/test_attention_perf.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 115


🏁 Script executed:

rg -B2 -A15 "^_SAGE_CONFIGS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 779


🏁 Script executed:

rg -B2 -A15 "^_SAGE_CONFIG_IDS\s*=" tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 720


Add strict=True to zip() in parametrized test matrices.

Without strict=True, any future list-length mismatch between _SAGE_CONFIGS and _SAGE_CONFIG_IDS silently truncates cases and reduces coverage.

Suggested diff
-    @pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
+    @pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS, strict=True))

Also applies to: 968, 1005

🧰 Tools
🪛 Ruff (0.15.11)

[warning] 942-942: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` at line 942, The
parametrized tests use zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS) which will silently
drop cases if the two lists differ in length; update the pytest.mark.parametrize
calls to use zip(..., strict=True) for each occurrence (e.g. where
pytest.mark.parametrize("sage_cfg,cfg_id", zip(_SAGE_CONFIGS, _SAGE_CONFIG_IDS))
at lines shown and the other occurrences referencing _SAGE_CONFIGS and
_SAGE_CONFIG_IDS) so that a length mismatch raises immediately; modify each zip
invocation (including the ones at the other noted locations) to pass
strict=True.

Comment on lines +969 to +999
    def test_sage_vs_vanilla_quick(self, sage_cfg: SageAttentionConfig, cfg_id: str):
        """Compare SageAttention timing against VANILLA at a quick size.

        Does not assert a minimum speedup — the goal is to catch regressions
        where sage unexpectedly becomes much slower than plain SDPA.
        """
        if sage_cfg.qk_int8 and torch.cuda.get_device_capability()[1] == 3:
            pytest.skip("SM103 does not have Int8 Tensor Cores.")

        batch, num_heads, seq_len, head_dim = 1, 12, 4096, 128

        vanilla = self.benchmark.benchmark_single(
            batch, num_heads, seq_len, head_dim, backend="VANILLA", verbose=False
        )
        trtllm = self.benchmark.benchmark_single(
            batch, num_heads, seq_len, head_dim, backend="TRTLLM", verbose=False
        )
        sage = self._bench(batch, num_heads, seq_len, head_dim, sage_cfg)

        assert vanilla is not None, "VANILLA benchmark failed"
        assert trtllm is not None, "TRTLLM benchmark failed"
        assert sage is not None, f"SageAttention {cfg_id} benchmark failed"

        speedup_vs_vanilla = vanilla["avg_ms"] / sage["avg_ms"]
        speedup_vs_trtllm = trtllm["avg_ms"] / sage["avg_ms"]
        print(
            f"\n sage {cfg_id} at S={seq_len}: "
            f"avg={sage['avg_ms']:.3f}ms "
            f"vs_vanilla={speedup_vs_vanilla:.2f}x "
            f"vs_trtllm={speedup_vs_trtllm:.2f}x"
        )
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether perf CI list coverage exists for this feature area.
fd -i 'l0_perf.yml' tests/integration/test_lists/test-db
fd -i 'llm_perf_*.yml' tests/integration/test_lists/qa
fd -i 'test_perf_sanity.py' tests/integration/defs/perf

# Look for WAN/visual-gen/sage attention coverage in perf defs/lists.
rg -n --iglob '*l0*.yml' --iglob 'llm_perf_*.yml' --iglob 'test_perf_sanity.py' \
  'sage|sageattention|wan|visual_gen|attention' tests/integration

Repository: NVIDIA/TensorRT-LLM

Length of output: 50376


🏁 Script executed:

# Search specifically for test_attention_perf.py in test lists and perf files
rg -i 'test_attention_perf' tests/integration/test_lists/ tests/integration/defs/perf/ --max-count 20

Repository: NVIDIA/TensorRT-LLM

Length of output: 166


🏁 Script executed:

# Search for sage or sageattention specifically in perf test files
rg -i 'sage|sageattention' tests/integration/test_lists/test-db/ tests/integration/defs/perf/ --max-count 20

Repository: NVIDIA/TensorRT-LLM

Length of output: 6085


🏁 Script executed:

# Check the actual test method to see if there are assertions or guards
sed -n '969,1040p' tests/unittest/_torch/visual_gen/test_attention_perf.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2989


Add regression guards or a perf baseline assertion to catch SageAttention slowdowns.

The test_sage_vs_vanilla_quick and test_sage_vs_trtllm_wan_shapes tests (also at lines 1010–1038) print speedup metrics but do not assert minimum thresholds. The docstring explicitly states "Does not assert a minimum speedup." If SageAttention becomes materially slower, these tests pass anyway, allowing regressions to go undetected.

While the test file is listed in tests/integration/test_lists/test-db/l0_b200.yml, the lack of threshold enforcement means performance regressions slip through CI. Add either:

  1. A loose minimum speedup assertion (e.g., assert speedup_vs_vanilla > 0.9, f"slowdown: {speedup_vs_vanilla:.2f}x"), or
  2. A perf-sanity baseline in tests/integration/defs/perf/test_perf_sanity.py that records and trend-checks latencies.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/test_attention_perf.py` around lines 969 -
999, Add a loose performance assertion to prevent silent regressions: in
test_sage_vs_vanilla_quick (and similarly in test_sage_vs_trtllm_wan_shapes)
after computing speedup_vs_vanilla and speedup_vs_trtllm, add an assertion such
as assert speedup_vs_vanilla > 0.9, f"slowdown vs VANILLA:
{speedup_vs_vanilla:.2f}x" (and/or assert speedup_vs_trtllm > 0.9 with a similar
message) so the test fails when SageAttention is materially slower; update the
assertion threshold as needed for CI stability and keep the existing print
statements intact.

@tensorrt-cicd
Collaborator

PR_Github #45422 [ run ] completed with state SUCCESS. Commit: 2c0a848
/LLM/main/L0_MergeRequest_PR pipeline #35657 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation
