diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py new file mode 100644 index 00000000000..b2a647192e3 --- /dev/null +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom vLLM workers for sparse attention. + +``SparseAttnWorker``: Replaces attention implementations with ModelOpt sparse +variants on each Attention module after model loading. For MHA/GQA models the +impl is replaced entirely; for MLA models (DeepSeek) the prefill methods are +monkey-patched on the existing impl. + +``SparseQuantWorker``: Applies quantization first, then sparse attention via +direct module walk (registry stacking does not work due to ``_DMRegistryCls`` +forward identity check). + +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --enforce-eager +""" + +import fnmatch +import json +import os +from typing import Any + +from fakequant_worker import disable_compilation +try: + from vllm.attention.layer import Attention as VLLMAttention # vllm < 0.16 +except ModuleNotFoundError: + from vllm.model_executor.layers.attention import Attention as VLLMAttention # vllm >= 0.16 + +from vllm.v1.worker.gpu_worker import Worker as BaseWorker + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + patch_mla_impl_for_sparse, +) + +try: + from vllm.model_executor.layers.attention.mla_attention import MLACommonImpl + + _HAS_MLA = True +except ImportError: + _HAS_MLA = False + +# --------------------------------------------------------------------------- +# Configuration from environment variables +# --------------------------------------------------------------------------- + +sparse_config: dict[str, Any] = { + "sparse_cfg": os.environ.get("SPARSE_ATTN_CFG", None), + "calib_config_path": os.environ.get("SPARSE_CALIB_CONFIG_PATH", None), + "skip_softmax_threshold": os.environ.get("SKIP_SOFTMAX_THRESHOLD", None), +} + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +_DEFAULT_SPARSE_CFG = { + "sparse_cfg": { + "*attn*": { + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 1, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +def _build_sparse_config(env_config: dict[str, Any]) -> dict | None: + """Build sparse_cfg dict from env vars.""" + cfg_name = env_config["sparse_cfg"] + if cfg_name is None: + return None + # Try looking up preset from mtsa, fall back to default + cfg = getattr(mtsa, cfg_name, None) + if cfg is not None: + return cfg + # Use built-in default if name matches + if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"): + return _DEFAULT_SPARSE_CFG + raise ValueError( + f"Unknown sparse config: {cfg_name}. Set SPARSE_ATTN_CFG to 'default' or a valid preset name." + ) + + +def _load_sparse_config(path: str) -> dict: + """Load offline calibration config JSON.""" + with open(path) as f: + calib_cfg = json.load(f) + + sparse_cfg = {} + for pattern, layer_cfg in calib_cfg.items(): + if pattern == "calibration": + sparse_cfg[pattern] = layer_cfg + continue + layer_cfg.setdefault("method", "triton_sparse_softmax") + layer_cfg.setdefault("backend", "triton") + layer_cfg.setdefault("enable", True) + sparse_cfg[pattern] = layer_cfg + sparse_cfg["default"] = {"enable": False} + + return {"sparse_cfg": sparse_cfg} + + +def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: + """Match a module name against sparse_cfg patterns.""" + cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) + for pattern, layer_cfg in cfg.items(): + if pattern in ("default", "calibration"): + continue + if fnmatch.fnmatch(module_name, pattern): + return layer_cfg + return None + + +def _build_sparse_kw(layer_cfg: dict, env_threshold: float | None) -> dict: + """Extract sparse kernel kwargs from a per-layer config dict. + + ``env_threshold`` (from SKIP_SOFTMAX_THRESHOLD) overrides any per-layer + ``skip_softmax_threshold`` when set. + """ + sparse_kw: dict = {} + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + threshold = layer_cfg.get("skip_softmax_threshold") + if env_threshold is not None: + threshold = env_threshold + if threshold: + sparse_kw["skip_softmax_threshold"] = float(threshold) + return sparse_kw + + +def _replace_attention_impl(worker, config: dict): + """Replace attention impls with ModelOpt sparse variants on all Attention layers. + + Handles both MHA/GQA layers (replace impl entirely) and MLA layers + (monkey-patch prefill methods on the existing impl). + + Shared by SparseAttnWorker and SparseQuantWorker. + """ + if config["calib_config_path"]: + cfg = _load_sparse_config(config["calib_config_path"]) + else: + cfg = _build_sparse_config(config) + + if cfg is None: + return + + env_threshold = config.get("skip_softmax_threshold") + env_threshold = float(env_threshold) if env_threshold is not None else None + + model = worker.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + patched_mha = 0 + patched_mla = 0 + # Group layers by their sparse_kw config so we can print a concise summary. + config_groups: dict[tuple, list[str]] = {} + for name, module in model.named_modules(): + if not isinstance(module, VLLMAttention): + continue + + # Match per-layer sparse config using name-based patterns + layer_cfg = _match_sparse_config(name, cfg) + if layer_cfg is None or not layer_cfg.get("enable", True): + continue + + sparse_kw = _build_sparse_kw(layer_cfg, env_threshold) + + # MLA layers: monkey-patch prefill methods (decode unchanged) + if _HAS_MLA and isinstance(module.impl, MLACommonImpl): + patch_mla_impl_for_sparse(module.impl, sparse_kw) + patched_mla += 1 + config_groups.setdefault(tuple(sorted(sparse_kw.items())), []).append(name) + continue + + # MHA/GQA layers: replace impl entirely + old_impl = module.impl + new_impl = ModelOptSparseAttentionImpl( + num_heads=old_impl.num_heads, + head_size=old_impl.head_size, + scale=old_impl.scale, + num_kv_heads=old_impl.num_kv_heads, + alibi_slopes=old_impl.alibi_slopes, + sliding_window=None, # overwritten below + kv_cache_dtype=old_impl.kv_cache_dtype, + logits_soft_cap=old_impl.logits_soft_cap, + attn_type=getattr(old_impl, "attn_type", module.attn_type), + kv_sharing_target_layer_name=getattr(old_impl, "kv_sharing_target_layer_name", None), + ) + # Copy the already-transformed sliding_window tuple directly, + # since __init__ transforms int -> (sw-1, 0) and we can't reverse it. + new_impl.sliding_window = old_impl.sliding_window + # Store per-layer sparse kwargs on the impl for forward() to read + new_impl.sparse_kw = sparse_kw + module.impl = new_impl + patched_mha += 1 + config_groups.setdefault(tuple(sorted(sparse_kw.items())), []).append(name) + + total = patched_mha + patched_mla + parts = [] + if patched_mha: + parts.append(f"{patched_mha} MHA/GQA") + if patched_mla: + parts.append(f"{patched_mla} MLA") + detail = " + ".join(parts) if parts else "0" + print( + f"[ModelOpt] Sparse attention: configured {total} attention layers ({detail})", + flush=True, + ) + for kw_items, layer_names in config_groups.items(): + kw = dict(kw_items) + variants = [] + if "sparsity_n" in kw: + variants.append(f"{kw['sparsity_n']}:{kw['sparsity_m']} N:M") + if "skip_softmax_threshold" in kw: + variants.append(f"skip_softmax(threshold={kw['skip_softmax_threshold']})") + variant_str = " + ".join(variants) if variants else "dense (no sparse kernel)" + print( + f"[ModelOpt] kernel={variant_str} " + f"sparse_kw={kw} layers={len(layer_names)} " + f"(e.g. {layer_names[0]})", + flush=True, + ) + + +# --------------------------------------------------------------------------- +# Workers +# --------------------------------------------------------------------------- + + +class SparseAttnWorker(BaseWorker): + """vLLM worker that uses the ModelOpt sparse attention backend. + + Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each + Attention module right after model loading — before any forward pass + (including determine_available_memory profiling). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self, sparse_config) + + +class SparseQuantWorker(BaseWorker): + """vLLM worker that applies quantization + sparse attention. + + Quantization uses the standard registry-based ``mtq.quantize()``. + Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl + (same approach as SparseAttnWorker). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self, sparse_config) + + def compile_or_warm_up_model(self) -> None: + """Apply quantization before warm-up.""" + from fakequant_worker import _fakequant_run_prolog_worker, quant_config + + model = self.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + with disable_compilation(model): + if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + _fakequant_run_prolog_worker(self) + + super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py new file mode 100644 index 00000000000..2d73c92ec83 --- /dev/null +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launch vLLM with sparse attention. + +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --max-model-len 8192 + +2:4 sparsity + skip-softmax (configurable threshold): + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_SKIP_DEFAULT \\ + SKIP_SOFTMAX_THRESHOLD=0.05 \\ + python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B + +Combined with quantization: + QUANT_CFG=INT8_SMOOTHQUANT_CFG SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT \\ + python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B +""" + +import os +import sys +from pathlib import Path + +import uvloop +import vllm +from packaging import version +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser + +vllm_version = version.parse(vllm.__version__) +if vllm_version <= version.parse("0.11.0"): + from vllm.utils import FlexibleArgumentParser +else: + from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Pass sparse attention env vars to ray workers (if supported by this vLLM version) +additional_env_vars = { + "SPARSE_ATTN_CFG", + "SPARSE_CALIB_CONFIG_PATH", + "SKIP_SOFTMAX_THRESHOLD", + "QUANT_DATASET", + "QUANT_CALIB_SIZE", + "QUANT_CFG", + "AMAX_FILE_PATH", + "KV_QUANT_CFG", +} + +try: + if vllm_version <= version.parse("0.11.0"): + from vllm.executor.ray_distributed_executor import RayDistributedExecutor + else: + from vllm.v1.executor.ray_executor import RayDistributedExecutor + if hasattr(RayDistributedExecutor, "ADDITIONAL_ENV_VARS"): + RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) +except ImportError: + pass # Ray not installed, single-node only + + +def main(): + """Launch vLLM with sparse attention worker.""" + parser = FlexibleArgumentParser(description="vLLM model server with sparse attention") + parser.add_argument("model", type=str, help="The path or name of the model to serve") + parser = make_arg_parser(parser) + + # Ensure workers can import our custom worker module + repo_root = str(Path(__file__).resolve().parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}" + + # Select worker based on env vars + has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG") + has_sparse = os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH") + + if has_quant and has_sparse: + worker_cls = "sparse_attn_worker.SparseQuantWorker" + elif has_sparse: + worker_cls = "sparse_attn_worker.SparseAttnWorker" + else: + print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.") + worker_cls = None + + if has_sparse: + print( + "[ModelOpt] Sparse attention enabled: " + f"SPARSE_ATTN_CFG={os.environ.get('SPARSE_ATTN_CFG')} " + f"SPARSE_CALIB_CONFIG_PATH={os.environ.get('SPARSE_CALIB_CONFIG_PATH')} " + f"SKIP_SOFTMAX_THRESHOLD={os.environ.get('SKIP_SOFTMAX_THRESHOLD')} " + f"worker_cls={worker_cls}", + flush=True, + ) + + if worker_cls: + parser.set_defaults(worker_cls=worker_cls) + + args = parser.parse_args() + uvloop.run(run_server(args)) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index a4b3cc90e32..662fe430ec8 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -80,6 +80,95 @@ def _load_sparsity_helpers() -> None: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] +# --------------------------------------------------------------------------- +# Paged KV cache helpers +# --------------------------------------------------------------------------- +@triton.jit +def _load_paged_k_tile( + K_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load K^T tile [BLOCK_D, BLOCK_N] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos # absolute token positions + kv_valid = kv_abs < seq_len_kv + + # Translate token positions -> (page_id, offset_in_page) + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # Load K values: K_cache[page_global, offset_in_page, kv_head_idx, dim] + # K^T layout [BLOCK_D, BLOCK_N] for Q @ K^T matmul + k_ptrs = ( + page_global[None, :] * stride_kc_block + + offset_in_page[None, :] * stride_kc_pos + + kv_head_idx * stride_kc_head + + dim_pos[:, None] + ) + return tl.load(K_cache + k_ptrs, mask=kv_valid[None, :] & d_mask[:, None], other=0.0) + + +@triton.jit +def _load_paged_v_tile( + V_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load V tile [BLOCK_N, BLOCK_D] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos + kv_valid = kv_abs < seq_len_kv + + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # V layout [BLOCK_N, BLOCK_D] + v_ptrs = ( + page_global[:, None] * stride_vc_block + + offset_in_page[:, None] * stride_vc_pos + + kv_head_idx * stride_vc_head + + dim_pos[None, :] + ) + return tl.load(V_cache + v_ptrs, mask=kv_valid[:, None] & d_mask[None, :], other=0.0) + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -116,6 +205,7 @@ def _attn_fwd( K, # [total_kv, num_kv_heads, head_dim] key tensor V, # [total_kv, num_kv_heads, head_dim] value tensor qk_scale, # softmax_scale * log2(e) + sm_scale, # softmax_scale (1/sqrt(head_dim) by default); used by scale-factor skip-softmax b_start_loc, # [batch] start offset of each Q sequence b_seq_len, # [batch] length of each Q sequence b_start_loc_k, # [batch] start offset of each KV sequence @@ -145,10 +235,24 @@ def _attn_fwd( NUM_SINK_TOKENS: tl.constexpr = 0, # KV positions before this are kept dense (attention sinks) DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores - SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores + SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # Fixed mode: log2(lambda) * sm_scale, pre-scaled for scaled scores + USE_SKIP_SCALE_FACTOR: tl.constexpr = False, # Scale-factor mode: threshold = scale_factor / seq_k per sequence + SKIP_THRESHOLD_SCALE_LOG2: tl.constexpr = 0.0, # Scale-factor mode: log2(scale_factor) * sm_scale Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic) Sparsity_skipped=None, # Optional int64 scalar for counting skipped tiles (atomic) MEASURE_SPARSITY: tl.constexpr = False, # When True, count total/skipped tiles via atomic adds + IS_PAGED: tl.constexpr = False, # Whether K/V are in paged cache + K_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged K + V_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged V + Block_table=None, # [batch, max_blocks_per_seq] page table + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE: tl.constexpr = 16, + max_blocks_per_seq=0, ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -168,6 +272,19 @@ def _attn_fwd( if tile_q * BLOCK_M >= seq_len_q: return # This Q tile is past the sequence end + # --- Per-program effective skip-softmax threshold (pre-scaled log2 space) --- + # Two modes share the same comparison: tile_row_max < (row_max + skip_threshold_log2). + # - Fixed mode: threshold = lambda (constant). Pre-computed in Python. + # - Scale-factor mode: threshold = scale_factor / seq_k_per_sequence. Subtract + # log2(seq_k) * sm_scale from the pre-computed log2(scale_factor) * sm_scale. + if APPLY_SKIP_SOFTMAX: + if USE_SKIP_SCALE_FACTOR: + skip_threshold_log2 = SKIP_THRESHOLD_SCALE_LOG2 - tl.log2( + seq_len_kv.to(tl.float32) + ) * sm_scale + else: + skip_threshold_log2 = SKIP_THRESHOLD_LOG2 + # --- Tile position indices --- q_pos = tile_q * BLOCK_M + tl.arange(0, BLOCK_M) # Absolute Q token positions kv_pos = tl.arange(0, BLOCK_N) # Relative KV positions within a tile @@ -195,12 +312,32 @@ def _attn_fwd( kv_start = tl.multiple_of(kv_start, BLOCK_N) # Compiler hint for alignment # Load K^T [BLOCK_D, BLOCK_N] (transposed layout for Q @ K^T matmul) - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) + if IS_PAGED: + k = _load_paged_k_tile( + K_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) # scores = Q @ K^T * scale [BLOCK_M, BLOCK_N] scores = tl.dot(q, k) * qk_scale @@ -229,7 +366,7 @@ def _attn_fwd( skip_tile = _skip_softmax_decision( scores, row_max, - SKIP_THRESHOLD_LOG2, + skip_threshold_log2, Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY, @@ -245,12 +382,32 @@ def _attn_fwd( acc = acc * correction[:, None] # Load V and accumulate - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new # else: tile skipped — no softmax, no V load, no BMM2 for this tile @@ -358,6 +515,8 @@ def _attn_bwd_dq( DENSE_WINDOW_SIZE: tl.constexpr = 64, APPLY_SKIP_SOFTMAX: tl.constexpr = False, SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, + USE_SKIP_SCALE_FACTOR: tl.constexpr = False, + SKIP_THRESHOLD_SCALE_LOG2: tl.constexpr = 0.0, ): """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles. @@ -402,6 +561,15 @@ def _attn_bwd_dq( dq = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + # Per-program effective skip-softmax threshold (see forward kernel for derivation). + if APPLY_SKIP_SOFTMAX: + if USE_SKIP_SCALE_FACTOR: + skip_threshold_log2 = SKIP_THRESHOLD_SCALE_LOG2 - tl.log2( + seq_len_kv.to(tl.float32) + ) * sm_scale + else: + skip_threshold_log2 = SKIP_THRESHOLD_LOG2 + # --- Loop over KV tiles: recompute S, then compute dQ contribution --- for kv_start in range(0, kv_bound, BLOCK_N): kv_mask = (kv_start + kv_pos) < seq_len_kv @@ -448,7 +616,7 @@ def _attn_bwd_dq( # max, so this conservatively zeros out at least what forward skipped. if APPLY_SKIP_SOFTMAX: tile_row_max = tl.max(scores, 1) - can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) + can_skip = tile_row_max < (lse + skip_threshold_log2) p = tl.where(can_skip[:, None], 0.0, p) # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K @@ -504,6 +672,8 @@ def _attn_bwd_dkdv( DENSE_WINDOW_SIZE: tl.constexpr = 64, APPLY_SKIP_SOFTMAX: tl.constexpr = False, SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, + USE_SKIP_SCALE_FACTOR: tl.constexpr = False, + SKIP_THRESHOLD_SCALE_LOG2: tl.constexpr = 0.0, ): """Phase 2 of backward: compute dK, dV for one KV tile. @@ -547,6 +717,15 @@ def _attn_bwd_dkdv( dk = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32) + # Per-program effective skip-softmax threshold (see forward kernel for derivation). + if APPLY_SKIP_SOFTMAX: + if USE_SKIP_SCALE_FACTOR: + skip_threshold_log2 = SKIP_THRESHOLD_SCALE_LOG2 - tl.log2( + seq_len_kv.to(tl.float32) + ) * sm_scale + else: + skip_threshold_log2 = SKIP_THRESHOLD_LOG2 + n_q_tiles = (seq_len_q + BLOCK_M - 1) // BLOCK_M # Causal: Q position i attends to KV 0..i, so this KV tile (at kv_start) # only receives gradients from Q tiles where q_pos >= kv_start. Skip earlier ones. @@ -602,7 +781,7 @@ def _attn_bwd_dkdv( # max, so this conservatively zeros out at least what forward skipped. if APPLY_SKIP_SOFTMAX: tile_row_max = tl.max(scores, 1) - can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2) + can_skip = tile_row_max < (lse + skip_threshold_log2) p = tl.where(can_skip[:, None], 0.0, p) # dV += P^T @ dO @@ -643,6 +822,12 @@ def forward( skip_softmax_threshold, skip_softmax_raw_threshold, measure_sparsity, + skip_softmax_threshold_scale_prefill, + skip_softmax_threshold_scale_decode, + k_cache, + v_cache, + block_table, + page_size, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -650,6 +835,8 @@ def forward( kv_group_num = num_q_heads // num_kv_heads batch = b_seq_len.shape[0] + is_paged = k_cache is not None + # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V. # Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata. if b_seq_len_k is None: @@ -657,29 +844,67 @@ def forward( b_start_loc_k = b_start_loc max_input_len_k = max_input_len + # Paged mode: b_start_loc_k may be None (KV is in paged cache, not contiguous). + # Provide a dummy tensor so Triton can compile the tl.load (it won't be used). + if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + # Pre-multiply scale by log2(e) so the kernel can use exp2() # exp(score * sm_scale) = exp2(score * sm_scale * log2(e)) qk_scale = sm_scale * LOG2E # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax threshold in scaled log2 space for the kernel. - # Two modes: - # 1. raw_threshold: passed directly as skip_threshold_log2 (for testing) - # 2. lambda threshold: converted via log2(lambda) * sm_scale - if skip_softmax_raw_threshold is not None: - apply_skip = True + # Skip-softmax: convert threshold to scaled log2 space for the kernel. + # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks + # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space + # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we + # pre-scale: threshold_scaled = log2(lambda) * sm_scale. + # + # Three modes (mutually exclusive, listed by precedence): + # - Raw: skip_softmax_raw_threshold passed directly as the kernel's + # ``skip_threshold_log2`` (for testing). + # - Fixed: skip_softmax_threshold = lambda (a constant). + # - Scale-factor: per-phase scale_factor; effective threshold is + # scale_factor / seq_k_per_sequence, computed inside the kernel. + # Phase is inferred here from max_input_len (decode iff == 1), + # matching the FlashSkipSoftmax PyTorch convention. + raw_mode = skip_softmax_raw_threshold is not None + fixed_mode = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 + scale_prefill_set = ( + skip_softmax_threshold_scale_prefill is not None + and skip_softmax_threshold_scale_prefill > 0.0 + ) + scale_decode_set = ( + skip_softmax_threshold_scale_decode is not None + and skip_softmax_threshold_scale_decode > 0.0 + ) + assert not (fixed_mode and (scale_prefill_set or scale_decode_set)), ( + "skip_softmax_threshold (fixed mode) is mutually exclusive with " + "skip_softmax_threshold_scale_{prefill,decode} (scale-factor mode)." + ) + + is_decode = max_input_len == 1 + active_scale = ( + skip_softmax_threshold_scale_decode if is_decode + else skip_softmax_threshold_scale_prefill + ) + use_scale_factor = active_scale is not None and active_scale > 0.0 + apply_skip = raw_mode or fixed_mode or use_scale_factor + + if raw_mode: skip_threshold_log2 = skip_softmax_raw_threshold - elif skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: - apply_skip = True - # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks - # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space - # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we - # pre-scale: threshold_scaled = log2(lambda) * sm_scale. + skip_threshold_scale_log2 = 0.0 + elif fixed_mode: skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale + skip_threshold_scale_log2 = 0.0 + elif use_scale_factor: + skip_threshold_log2 = 0.0 + skip_threshold_scale_log2 = math.log2(active_scale) * sm_scale else: apply_skip = False skip_threshold_log2 = 0.0 + skip_threshold_scale_log2 = 0.0 o = torch.empty_like(q) lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) @@ -702,6 +927,7 @@ def grid(META): k, v, qk_scale, + sm_scale, b_start_loc, b_seq_len, b_start_loc_k, @@ -730,9 +956,23 @@ def grid(META): DENSE_WINDOW_SIZE=dense_window_size, APPLY_SKIP_SOFTMAX=apply_skip, SKIP_THRESHOLD_LOG2=skip_threshold_log2, + USE_SKIP_SCALE_FACTOR=use_scale_factor, + SKIP_THRESHOLD_SCALE_LOG2=skip_threshold_scale_log2, Sparsity_total=sparsity_total, Sparsity_skipped=sparsity_skipped, MEASURE_SPARSITY=do_measure, + IS_PAGED=is_paged, + K_cache=k_cache, + V_cache=v_cache, + Block_table=block_table, + stride_kc_block=k_cache.stride(0) if is_paged else 0, + stride_kc_pos=k_cache.stride(1) if is_paged else 0, + stride_kc_head=k_cache.stride(2) if is_paged else 0, + stride_vc_block=v_cache.stride(0) if is_paged else 0, + stride_vc_pos=v_cache.stride(1) if is_paged else 0, + stride_vc_head=v_cache.stride(2) if is_paged else 0, + PAGE_SIZE=page_size, + max_blocks_per_seq=block_table.shape[1] if is_paged else 0, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -758,6 +998,8 @@ def grid(META): ctx.dense_window_size = dense_window_size ctx.apply_skip = apply_skip ctx.skip_threshold_log2 = skip_threshold_log2 + ctx.use_skip_scale_factor = use_scale_factor + ctx.skip_threshold_scale_log2 = skip_threshold_scale_log2 return o @staticmethod @@ -838,6 +1080,8 @@ def backward(ctx, grad_output): DENSE_WINDOW_SIZE=ctx.dense_window_size, APPLY_SKIP_SOFTMAX=ctx.apply_skip, SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + USE_SKIP_SCALE_FACTOR=ctx.use_skip_scale_factor, + SKIP_THRESHOLD_SCALE_LOG2=ctx.skip_threshold_scale_log2, num_warps=num_warps, num_stages=1, ) @@ -863,6 +1107,8 @@ def backward(ctx, grad_output): DENSE_WINDOW_SIZE=ctx.dense_window_size, APPLY_SKIP_SOFTMAX=ctx.apply_skip, SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + USE_SKIP_SCALE_FACTOR=ctx.use_skip_scale_factor, + SKIP_THRESHOLD_SCALE_LOG2=ctx.skip_threshold_scale_log2, num_warps=num_warps, num_stages=1, ) @@ -871,21 +1117,27 @@ def backward(ctx, grad_output): dq, dk, dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + None, # b_start_loc + None, # b_seq_len + None, # max_input_len + None, # is_causal + None, # sm_scale + None, # b_start_loc_k + None, # b_seq_len_k + None, # max_input_len_k + None, # sparsity_n + None, # sparsity_m + None, # num_sink_tokens + None, # dense_window_size + None, # skip_softmax_threshold + None, # skip_softmax_raw_threshold + None, # measure_sparsity + None, # skip_softmax_threshold_scale_prefill + None, # skip_softmax_threshold_scale_decode + None, # k_cache + None, # v_cache + None, # block_table + None, # page_size ) @@ -909,8 +1161,14 @@ def attention( skip_softmax_threshold: float | None = None, skip_softmax_raw_threshold: float | None = None, measure_sparsity: bool = False, + skip_softmax_threshold_scale_prefill: float | None = None, + skip_softmax_threshold_scale_decode: float | None = None, + k_cache: torch.Tensor | None = None, + v_cache: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + page_size: int = 16, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. + """Variable-length flash attention with GQA, autograd, optional sparsity, and paged KV. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -933,7 +1191,7 @@ def attention( (attention sinks). Absolute token count, BLOCK_N-independent. dense_window_size: Tokens near the query diagonal kept dense (local attention window). Absolute token count, BLOCK_N-independent. - Default 64 (one reference block). + Default 64 tokens. skip_softmax_threshold: BLASST threshold lambda (https://arxiv.org/pdf/2512.12087). Skip KV tiles where ``exp(tile_max - running_max) < lambda``, meaning the tile's @@ -950,6 +1208,23 @@ def attention( and skipped tiles via atomic counters. The counts are stored as ``_sparsity_total`` and ``_sparsity_skipped`` attributes on the returned output tensor. + skip_softmax_threshold_scale_prefill: Scale-factor mode (prefill). + When set, the effective skip threshold is computed per sequence + as ``scale_factor / seq_k`` (matches the FlashSkipSoftmax PyTorch + calibrated path: ``a * exp(b * target_sparsity) / seqlen``). Used + only when ``max_input_len > 1``. Mutually exclusive with + ``skip_softmax_threshold``. + skip_softmax_threshold_scale_decode: Scale-factor mode (decode). + Same semantics as ``skip_softmax_threshold_scale_prefill`` but + applied when ``max_input_len == 1`` (decode). Mutually exclusive + with ``skip_softmax_threshold``. + k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim]. + When provided, K/V are read from paged cache via block_table + instead of from contiguous k/v tensors. + v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim]. + block_table: Page table [batch, max_blocks_per_seq] mapping sequence + block indices to global page IDs. + page_size: Number of tokens per page in the KV cache. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. @@ -975,7 +1250,152 @@ def attention( skip_softmax_threshold, skip_softmax_raw_threshold, measure_sparsity, + skip_softmax_threshold_scale_prefill, + skip_softmax_threshold_scale_decode, + k_cache, + v_cache, + block_table, + page_size, ) -__all__ = ["LOG2E", "_apply_mask", "attention"] +def attention_with_lse( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, + *, + sparsity_n: int = 0, + sparsity_m: int = 4, + num_sink_tokens: int = 0, + dense_window_size: int = 64, + skip_softmax_threshold: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Variable-length flash attention returning both output and LSE. + + Same interface as :func:`attention` but returns ``(output, lse)`` where + *lse* is the log-sum-exp in **natural-log** space with shape + ``[total_q_tokens, num_q_heads]``. Intended for inference workloads that + need LSE for attention-state merging (e.g. MLA chunked prefill). + + This function does **not** support paged KV cache or autograd — use the + contiguous Q/K/V path only. + + Args: + q: [total_q_tokens, num_q_heads, head_dim] + k: [total_kv_tokens, num_kv_heads, head_dim] + v: [total_kv_tokens, num_kv_heads, head_dim] + b_start_loc: [batch] start offset of each Q sequence in the flat tensor. + b_seq_len: [batch] length of each Q sequence. + max_input_len: Maximum Q sequence length (for grid sizing). + is_causal: Whether to apply causal masking. + softmax_scale: Scale factor (default: 1/sqrt(head_dim)). + b_start_loc_k: [batch] start offset for K/V (None = same as Q). + b_seq_len_k: [batch] length for K/V (None = same as Q). + max_input_len_k: Maximum K/V sequence length (None = same as Q). + sparsity_n: N:M sparsity — keep top-N of every M attention scores. + sparsity_m: N:M sparsity — group size (4 or 8). + num_sink_tokens: KV positions before this index are kept dense. + dense_window_size: Tokens near the query diagonal kept dense. + skip_softmax_threshold: BLASST threshold lambda. + + Returns: + (output, lse): + output: [total_q_tokens, num_q_heads, head_dim] + lse: [total_q_tokens, num_q_heads] in natural-log space + """ + sm_scale = 1.0 / (q.shape[2] ** 0.5) if softmax_scale is None else softmax_scale + + HEAD_DIM = q.shape[2] + num_q_heads = q.shape[1] + num_kv_heads = k.shape[1] + kv_group_num = num_q_heads // num_kv_heads + batch = b_seq_len.shape[0] + + if b_seq_len_k is None: + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + + if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + + qk_scale = sm_scale * LOG2E + BLOCK_D = triton.next_power_of_2(HEAD_DIM) + + apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 + if apply_skip: + assert skip_softmax_threshold is not None + skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale + else: + skip_threshold_log2 = 0.0 + + o = torch.empty_like(q) + lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) + + def grid(META): + return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) + + _attn_fwd[grid]( + q, + k, + v, + qk_scale, + sm_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + lse, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + lse.stride(0), + lse.stride(1), + N_CTX=max_input_len, + kv_group_num=kv_group_num, + BLOCK_D=BLOCK_D, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + STORE_LSE=True, + SPARSITY_N=sparsity_n, + SPARSITY_M=sparsity_m, + NUM_SINK_TOKENS=num_sink_tokens, + DENSE_WINDOW_SIZE=dense_window_size, + APPLY_SKIP_SOFTMAX=apply_skip, + SKIP_THRESHOLD_LOG2=skip_threshold_log2, + IS_PAGED=False, + K_cache=None, + V_cache=None, + Block_table=None, + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE=16, + max_blocks_per_seq=0, + ) + + # Convert LSE from log2 space to natural-log space. + # Kernel stores: row_max + log2(row_sum) where row_max is in log2-scaled space. + # Standard LSE = stored_lse * ln(2). + lse.mul_(math.log(2)) + + return o, lse + + +__all__ = ["LOG2E", "_apply_mask", "attention", "attention_with_lse"] diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index f066f9c4b7d..9d534eec4c8 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -142,7 +142,7 @@ def _apply_sparse_nm_to_qk_tile( def _skip_softmax_decision( scores, row_max, - SKIP_THRESHOLD_LOG2: tl.constexpr, + skip_threshold_log2, Sparsity_total, Sparsity_skipped, MEASURE_SPARSITY: tl.constexpr, @@ -167,7 +167,7 @@ def _skip_softmax_decision( """ tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) # Per-row: True if row's tile max is negligible vs running max - can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) + can_skip = tile_row_max < (row_max + skip_threshold_log2) # Per-tile: skip entire tile only if ALL rows are negligible skip_tile = tl.min(can_skip.to(tl.int32)) == 1 diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index 95ca3240b73..346172f82f6 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -24,8 +24,22 @@ import torch + +def _vllm_module_spec_exists(name: str) -> bool: + """True if *name* is importable; safe when parent packages are absent. + + ``importlib.util.find_spec`` can raise ``ModuleNotFoundError`` for nested + names (e.g. ``vllm.attention.layers``) when ``vllm.attention`` was removed + in newer vLLM layouts, instead of returning ``None``. + """ + try: + return importlib.util.find_spec(name) is not None + except ModuleNotFoundError: + return False + + # Try multiple import paths for vLLM compatibility across versions -if importlib.util.find_spec("vllm.attention"): +if _vllm_module_spec_exists("vllm.attention"): import vllm.attention as vllm_attention # vllm < 0.16.0 else: import vllm.model_executor.layers.attention as vllm_attention # vllm >= 0.16.0 @@ -50,12 +64,7 @@ except ImportError: continue -try: - _has_attention_layers = importlib.util.find_spec("vllm.attention.layers") is not None -except (ModuleNotFoundError, ValueError): - _has_attention_layers = False - -if _has_attention_layers: # vllm < 0.15.0 +if _vllm_module_spec_exists("vllm.attention.layers"): # vllm < 0.15.0 from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention else: @@ -68,12 +77,7 @@ except ImportError: EncoderOnlyAttention = None -try: - _has_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None -except (ModuleNotFoundError, ValueError): - _has_attention_layer = False - -if _has_attention_layer: +if _vllm_module_spec_exists("vllm.attention.layer"): import vllm.attention.layer as vllm_attention try: diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index eed50b87af1..f9a57d382cb 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -693,7 +693,25 @@ class VSAConfig(SparseAttentionConfig): "sparse_cfg": { "*attn*": { "method": "triton_skip_softmax", - "skip_softmax_threshold": 0.1, + "skip_softmax_threshold": 0.001, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# 2:4 sparsity combined with skip-softmax tile skipping (Triton kernel) +SPARSE_SOFTMAX_SKIP_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 64, + "skip_softmax_threshold": 0.001, "backend": "triton", "enable": True, }, @@ -707,6 +725,7 @@ class VSAConfig(SparseAttentionConfig): "SKIP_SOFTMAX_DEFAULT", "SKIP_SOFTMAX_TRITON_DEFAULT", "SPARSE_SOFTMAX_DEFAULT", + "SPARSE_SOFTMAX_SKIP_DEFAULT", "VSA_DEFAULT", "CalibrationConfig", "FlashSkipSoftmaxConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index 434fc18214b..3a513d52c53 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,6 +15,8 @@ """Plugins for sparse attention integration with various frameworks.""" +from modelopt.torch.utils import import_plugin + # List of model plugins that are called during conversion # Each plugin is a callable that takes (model) and performs validation/setup CUSTOM_MODEL_PLUGINS: list = [] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py new file mode 100644 index 00000000000..34724c1953d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt sparse attention backend for vLLM. + +Registers a custom vLLM attention backend that uses the ModelOpt Triton kernel +with paged KV cache support. Integration approach: + +- No module replacement — the Attention module stays intact with all its state +- Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl +- KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) +- Only ``forward()`` is overridden to call our Triton kernel for both prefill and decode + +For MLA (Multi-Latent Attention) models like DeepSeek, a different strategy is used: +the MLA impl's prefill methods are monkey-patched to call our Triton kernel in +contiguous (non-paged) mode, since MLA decompresses KV latents before attention. +""" + +import importlib.util +import types + +import torch +import torch.nn.functional as F +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, +) + +from modelopt.torch.kernels.common.attention.triton_fa import attention as triton_attention +from modelopt.torch.kernels.common.attention.triton_fa import attention_with_lse + +_HAS_MLA = ( + importlib.util.find_spec("vllm.model_executor.layers.attention.mla_attention") is not None +) + +logger = init_logger(__name__) + +_LOGGED_CONFIGS: set[tuple] = set() + + +class ModelOptSparseAttentionImpl(FlashAttentionImpl): + """Attention implementation that uses the ModelOpt Triton kernel. + + Inherits from FlashAttentionImpl to reuse: + - __init__ (all configuration) + - do_kv_cache_update (KV cache writing) + Only overrides forward() to replace the attention computation. + """ + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward with ModelOpt Triton sparse attention kernel.""" + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run + return output.fill_(0) + + num_actual_tokens = attn_metadata.num_actual_tokens + is_prefill = attn_metadata.max_query_len > 1 + + # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] + key_cache, value_cache = kv_cache.unbind(0) + page_size = key_cache.shape[1] + + # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) + sparse_kw = getattr(self, "sparse_kw", {}) + + # Log once per unique sparse_kw config so each variant is reported + # the first time it actually executes (prefill or decode). + kw_key = tuple(sorted(sparse_kw.items())) + if kw_key not in _LOGGED_CONFIGS: + variants = [] + if "sparsity_n" in sparse_kw: + variants.append( + f"{sparse_kw['sparsity_n']}:{sparse_kw['sparsity_m']} N:M" + ) + if "skip_softmax_threshold" in sparse_kw: + variants.append( + f"skip_softmax(threshold={sparse_kw['skip_softmax_threshold']})" + ) + variant_str = " + ".join(variants) if variants else "dense" + msg = ( + f"[ModelOpt] sparse Triton kernel active: kernel={variant_str} " + f"phase={'prefill' if is_prefill else 'decode'} " + f"layer={layer.__class__.__name__} " + f"sparse_kw={sparse_kw}" + ) + logger.warning(msg) + print(msg, flush=True) + _LOGGED_CONFIGS.add(kw_key) + + # Prepare metadata for our kernel + q = query[:num_actual_tokens].contiguous() + cu_seqlens_q = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + batch = seq_lens.shape[0] + + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + + # Dummy K/V for paged mode: not used by the kernel (KV are read from + # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads + # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + + # Call ModelOpt Triton kernel with paged KV. + # b_seq_len is the query length (e.g., 6 for prefill, 1 for decode). + # b_seq_len_k is the total KV length including cache (e.g., 6 for first + # prefill, 7/8/... for subsequent decode steps). + triton_out = triton_attention( + q, + k=k_dummy, + v=k_dummy, + # Query metadata + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=attn_metadata.max_query_len, + is_causal=is_prefill, # causal for prefill, non-causal for decode + softmax_scale=self.scale, + # KV metadata + b_start_loc_k=None, # paged mode: KV offsets not needed + b_seq_len_k=seq_lens, # total KV length per sequence + max_input_len_k=attn_metadata.max_seq_len, + # Paged KV cache + k_cache=key_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + v_cache=value_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + block_table=attn_metadata.block_table, # [batch, max_blocks] + page_size=page_size, # tokens per page in the KV cache + **sparse_kw, + ) + + output[:num_actual_tokens] = triton_out + return output + + +class ModelOptSparseAttentionBackend(FlashAttentionBackend): + """Attention backend that uses ModelOpt's sparse Triton kernel. + + Inherits everything from FlashAttentionBackend except get_impl_cls and get_name. + """ + + @staticmethod + def get_name() -> str: + """Return backend name.""" + return "MODELOPT_SPARSE" + + @staticmethod + def get_impl_cls() -> type: + """Return the attention implementation class.""" + return ModelOptSparseAttentionImpl + + +# --------------------------------------------------------------------------- +# MLA (Multi-Latent Attention) sparse prefill support +# --------------------------------------------------------------------------- +# MLA models (DeepSeek) decompress KV latents to full Q, K, V tensors before +# calling attention in the prefill path. We replace the prefill methods on the +# MLA impl to use our Triton kernel in contiguous mode. V is zero-padded to +# match Q/K head_dim; the caller (_forward_prefill) slices the output back to +# V's head_dim when _pad_v=True. +# +# Decode is unchanged — it uses specialized MLA-aware backends (FlashInfer MLA, +# FlashMLA, TRT-LLM) that operate on compressed latents. +# --------------------------------------------------------------------------- + + +def _modelopt_mla_run_prefill_new_tokens(self, prefill, q, k, v, return_softmax_lse): + """ModelOpt sparse attention for MLA new tokens (causal). + + Replaces ``MLACommonImpl._run_prefill_new_tokens`` when sparse attention + is enabled. Pads V to Q's head_dim, calls the ModelOpt Triton kernel in + contiguous mode, and returns LSE in ``[num_heads, total_tokens]`` format. + """ + padded_v = F.pad(v, [0, q.shape[-1] - v.shape[-1]]) if v.shape[-1] < q.shape[-1] else v + + cu = prefill.query_start_loc + batch = cu.shape[0] - 1 + b_start_loc = cu[:batch] + b_seq_len = cu[1 : batch + 1] - cu[:batch] + + o, lse = attention_with_lse( + q, + k, + padded_v, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=prefill.max_query_len, + is_causal=True, + softmax_scale=self.scale, + **self._modelopt_sparse_kw, + ) + + if return_softmax_lse: + return o, lse.transpose(0, 1).contiguous() + return o + + +def _modelopt_mla_run_prefill_context_chunk(self, prefill, chunk_idx, q, k, v): + """ModelOpt sparse attention for MLA context chunks (non-causal). + + Replaces ``MLACommonImpl._run_prefill_context_chunk``. Always returns + ``(output, lse)`` since chunked context needs LSE for merging. + """ + padded_v = F.pad(v, [0, q.shape[-1] - v.shape[-1]]) if v.shape[-1] < q.shape[-1] else v + + cu_q = prefill.query_start_loc + cu_k = prefill.chunked_context.cu_seq_lens[chunk_idx] + batch = cu_q.shape[0] - 1 + + sparse_kw = dict(self._modelopt_sparse_kw) + sparse_kw["dense_window_size"] = 0 # no dense window for non-causal context + + o, lse = attention_with_lse( + q, + k, + padded_v, + b_start_loc=cu_q[:batch], + b_seq_len=cu_q[1 : batch + 1] - cu_q[:batch], + max_input_len=prefill.max_query_len, + is_causal=False, + softmax_scale=self.scale, + b_start_loc_k=cu_k[:batch], + b_seq_len_k=cu_k[1 : batch + 1] - cu_k[:batch], + max_input_len_k=prefill.chunked_context.max_seq_lens[chunk_idx], + **sparse_kw, + ) + + return o, lse.transpose(0, 1).contiguous() + + +def patch_mla_impl_for_sparse(impl, sparse_kw: dict) -> None: + """Monkey-patch an MLACommonImpl to use ModelOpt sparse prefill. + + Sets ``_pad_v=True`` so that ``_forward_prefill`` slices the output back + to ``v_head_dim`` after attention. Replaces the prefill method pointers + with our Triton-kernel-based implementations. + + Args: + impl: An ``MLACommonImpl`` instance (or subclass like ``FlashInferMLAImpl``). + sparse_kw: Sparse attention config dict with keys like ``sparsity_n``, + ``sparsity_m``, ``num_sink_tokens``, ``dense_window_size``. + """ + impl._modelopt_sparse_kw = sparse_kw + impl._pad_v = True + impl._run_prefill_new_tokens = types.MethodType(_modelopt_mla_run_prefill_new_tokens, impl) + impl._run_prefill_context_chunk = types.MethodType( + _modelopt_mla_run_prefill_context_chunk, impl + ) diff --git a/pyproject.toml b/pyproject.toml index 7d45bdbd920..94e8ca61eea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,8 @@ dependencies = [ "rich", "safetensors", "scipy", + "uvloop>=0.22.1", + "vllm==0.20.1", ] [project.optional-dependencies] @@ -248,6 +250,7 @@ convention = "google" [tool.ruff.lint.isort] known-first-party = ["modelopt"] +known-third-party = ["vllm"] split-on-trailing-comma = false diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py index 56f0a9e9d86..49f3c882b0b 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py @@ -285,3 +285,215 @@ def test_skip_softmax_via_sparsify(self, tiny_llama_dir): assert not torch.isinf(logits_skip).any(), "Inf in skip-softmax logits" # On short sequences (64 tokens), no tiles are skipped — output should match dense torch.testing.assert_close(logits_skip, logits_dense, rtol=1e-3, atol=1e-3) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestSkipSoftmaxScaleFactor: + """Scale-factor mode: effective threshold = scale_factor / seq_k per sequence.""" + + def _make_inputs(self, batch=2, seq_len=256, num_heads=4, num_kv_heads=2, head_dim=64): + total = batch * seq_len + torch.manual_seed(77) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + locs, lens = make_varlen_meta([seq_len] * batch) + return q, k, v, locs, lens + + def test_disabled_matches_dense(self): + """scale_prefill/decode=None produces bit-identical output to dense.""" + q, k, v, locs, lens = self._make_inputs() + scale = 1.0 / (64**0.5) + out_none = attention(q, k, v, locs, lens, 256, softmax_scale=scale) + out_scale_none = attention( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=scale, + skip_softmax_threshold_scale_prefill=None, + skip_softmax_threshold_scale_decode=None, + ) + assert torch.equal(out_none, out_scale_none) + + def test_equivalent_to_fixed_when_seqlen_uniform(self): + """With uniform seq_k, scale-factor mode == fixed mode when threshold = scale / seq_k.""" + seq_len = 512 + q, k, v, locs, lens = self._make_inputs(batch=2, seq_len=seq_len) + scale = 1.0 / (64**0.5) + scale_factor = 5.0 + fixed_threshold = scale_factor / seq_len + + out_fixed = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + skip_softmax_threshold=fixed_threshold, + ) + out_scale = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + skip_softmax_threshold_scale_prefill=scale_factor, + ) + # Same algorithm path; should match to fp16 precision. + torch.testing.assert_close(out_scale, out_fixed, rtol=1e-3, atol=1e-3) + + def test_decode_phase_uses_decode_scale(self): + """In decode (max_input_len==1), only the decode scale factor is applied.""" + batch = 2 + seq_lens_k = [128, 128] # uniform so fixed/scale equivalence holds + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + scale_factor_decode = 5.0 + fixed_threshold = scale_factor_decode / seq_lens_k[0] + + torch.manual_seed(42) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + total_kv = sum(seq_lens_k) + k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + common = dict( + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + is_causal=False, + softmax_scale=scale, + ) + + out_fixed = attention( + q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, 1, + skip_softmax_threshold=fixed_threshold, **common, + ) + # Decode scale should fire because max_input_len == 1. + out_decode = attention( + q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, 1, + skip_softmax_threshold_scale_decode=scale_factor_decode, **common, + ) + torch.testing.assert_close(out_decode, out_fixed, rtol=1e-3, atol=1e-3) + + # Prefill scale alone should be a no-op (decode phase ignores prefill scale). + out_prefill_only = attention( + q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, 1, + skip_softmax_threshold_scale_prefill=scale_factor_decode, **common, + ) + out_dense = attention( + q_flat, k_flat, v_flat, b_start_loc_q, b_seq_len_q, 1, **common, + ) + assert torch.equal(out_prefill_only, out_dense) + + def test_prefill_phase_uses_prefill_scale(self): + """In prefill (max_input_len>1), only the prefill scale factor is applied.""" + seq_len = 256 + q, k, v, locs, lens = self._make_inputs(batch=1, seq_len=seq_len) + scale = 1.0 / (64**0.5) + + # Decode scale alone should be inactive in prefill — output equals dense. + out_dense = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + out_decode_only = attention( + q, k, v, locs, lens, seq_len, + softmax_scale=scale, + skip_softmax_threshold_scale_decode=5.0, + ) + assert torch.equal(out_decode_only, out_dense) + + def test_mutual_exclusivity(self): + """Setting both fixed threshold and a scale factor raises an error.""" + q, k, v, locs, lens = self._make_inputs() + scale = 1.0 / (64**0.5) + with pytest.raises(AssertionError, match="mutually exclusive"): + attention( + q, k, v, locs, lens, 256, + softmax_scale=scale, + skip_softmax_threshold=1e-3, + skip_softmax_threshold_scale_prefill=5.0, + ) + + def test_matches_pytorch_calibrated_reference(self): + """Triton scale-factor mode matches the FlashSkipSoftmax calibrated path.""" + from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import ( + FlashSkipSoftmax, + ) + + batch, seq_len = 1, 256 + num_heads, num_kv_heads, head_dim = 4, 4, 64 # MHA for simplicity + scale = 1.0 / (head_dim**0.5) + scale_factor = 5.0 # corresponds to fixed threshold ≈ 5/256 ≈ 0.0195 + + torch.manual_seed(123) + q_4d = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + k_4d = torch.randn( + batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float32 + ) + v_4d = torch.randn( + batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float32 + ) + + scores = torch.matmul(q_4d, k_4d.transpose(-2, -1)) * scale + causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="cuda"), diagonal=1).bool() + scores = scores.masked_fill(causal_mask[None, None, :, :], float("-inf")) + + # PyTorch reference: drive the calibrated dynamic-threshold path. + # FlashSkipSoftmax expects scale_factor = a * exp(b * target_sparsity); we + # bypass that decomposition by setting b=0 so scale_factor = a. + method = FlashSkipSoftmax( + method_config={ + "thresholds": {"prefill": [1.0]}, # unused when calibration_params is set + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + method.calibration_params = { + "prefill": {"a": scale_factor, "b": 0.0}, + "decode": {"a": scale_factor, "b": 0.0}, + } + method.target_sparse_ratio = {"prefill": 0.0, "decode": 0.0} + + sparse_mask, _ = method.calculate_sparsity(scores) + if sparse_mask is not None: + scores = scores.masked_fill(~sparse_mask, float("-inf")) + p = torch.softmax(scores, dim=-1) + ref_out = torch.matmul(p, v_4d) + + total = batch * seq_len + q_flat = q_4d.permute(0, 2, 1, 3).reshape(total, num_heads, head_dim).contiguous() + k_flat = k_4d.permute(0, 2, 1, 3).reshape(total, num_kv_heads, head_dim).contiguous() + v_flat = v_4d.permute(0, 2, 1, 3).reshape(total, num_kv_heads, head_dim).contiguous() + locs = torch.arange(batch, device="cuda", dtype=torch.int32) * seq_len + lens = torch.full((batch,), seq_len, device="cuda", dtype=torch.int32) + + triton_out = attention( + q_flat, + k_flat, + v_flat, + locs, + lens, + seq_len, + is_causal=True, + softmax_scale=scale, + skip_softmax_threshold_scale_prefill=scale_factor, + ) + triton_out_4d = triton_out.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3) + + torch.testing.assert_close(triton_out_4d, ref_out, rtol=5e-3, atol=5e-3) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py new file mode 100644 index 00000000000..f342bcd9ad2 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for paged KV cache mode of the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels import attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _scatter_to_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache + block table. + + Args: + k: [total_kv, num_kv_heads, head_dim] contiguous keys + v: [total_kv, num_kv_heads, head_dim] contiguous values + b_start_loc: [batch] start offsets + b_seq_len: [batch] sequence lengths + num_kv_heads: number of KV heads + head_dim: head dimension + page_size: tokens per page + + Returns: + k_cache: [num_blocks, page_size, num_kv_heads, head_dim] + v_cache: [num_blocks, page_size, num_kv_heads, head_dim] + block_table: [batch, max_blocks_per_seq] + """ + batch = b_seq_len.shape[0] + device = k.device + dtype = k.dtype + + # Calculate blocks needed per sequence + blocks_per_seq = [] + for b in range(batch): + slen = int(b_seq_len[b].item()) + blocks_per_seq.append((slen + page_size - 1) // page_size) + + max_blocks = max(blocks_per_seq) + num_blocks = sum(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + global_block = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = global_block + tok_start = blk * page_size + tok_end = min(tok_start + page_size, slen) + n_toks = tok_end - tok_start + k_cache[global_block, :n_toks] = k[start + tok_start : start + tok_end] + v_cache[global_block, :n_toks] = v[start + tok_start : start + tok_end] + global_block += 1 + + return k_cache, v_cache, block_table + + +# --------------------------------------------------------------------------- +# Paged KV cache tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestPagedKV: + """Paged KV cache mode tests — verify paged output matches contiguous.""" + + def test_paged_matches_contiguous(self): + """Paged mode produces same output as contiguous mode with identical data.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(42) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + # Build paged cache from the same K/V + locs_k, lens_k = locs, lens + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs_k, lens_k, num_kv_heads, head_dim, page_size + ) + + # Paged mode + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_start_loc_k=locs_k, + b_seq_len_k=lens_k, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_no_nan(self): + """Paged mode output is finite.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(55) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert not torch.isnan(out).any(), "NaN in paged output" + assert not torch.isinf(out).any(), "Inf in paged output" + + def test_paged_variable_length(self): + """Paged mode works with variable-length sequences.""" + seq_lens = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = sum(seq_lens) + + torch.manual_seed(77) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta(seq_lens) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + + # Paged + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + max(seq_lens), + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=max(seq_lens), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("page_size", [16, 32, 64]) + def test_paged_different_page_sizes(self, page_size): + """Paged mode works with different page sizes.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(88) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_with_sparsity(self): + """Paged mode works with N:M sparsity enabled.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged_sparse = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + sparsity_n=2, + sparsity_m=4, + ) + + assert not torch.isnan(out_paged_sparse).any(), "NaN in paged + sparse output" + assert not torch.isinf(out_paged_sparse).any(), "Inf in paged + sparse output" + assert out_paged_sparse.shape == q.shape + + def test_paged_decode(self): + """Paged mode works for decode (single Q token, long KV context).""" + batch = 2 + seq_lens_k = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total_kv = sum(seq_lens_k) + + torch.manual_seed(33) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + k_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + # Build paged cache + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k_flat, v_flat, b_start_loc_k, b_seq_len_k, num_kv_heads, head_dim, page_size + ) + + out = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert out.shape == q_flat.shape + assert not torch.isnan(out).any(), "NaN in paged decode output"