From c6092150247d6726240996cf72818308d87b0e62 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 05:20:53 +0000 Subject: [PATCH 001/114] add att backend impl --- .../common/basemodel/attention/__init__.py | 2 + .../common/basemodel/attention/base_att.py | 85 ++++++++++++ .../basemodel/attention/triton_backend.py | 123 ++++++++++++++++++ lightllm/common/basemodel/basemodel.py | 13 ++ lightllm/common/basemodel/infer_struct.py | 9 ++ .../triton_kernel/alibi_att}/__init__.py | 0 .../context_flashattention_nopad.py | 0 .../alibi_att}/token_attention_nopad_att1.py | 0 .../token_attention_nopad_reduceV.py | 0 .../token_attention_nopad_softmax.py | 41 ++++-- .../alibi_att}/token_flashattention_nopad.py | 0 .../layer_infer/transformer_layer_infer.py | 56 ++++---- lightllm/models/bloom/model.py | 6 + lightllm/server/api_cli.py | 22 ++++ lightllm/server/core/objs/start_args_type.py | 3 + 15 files changed, 315 insertions(+), 45 deletions(-) create mode 100644 lightllm/common/basemodel/attention/__init__.py create mode 100644 lightllm/common/basemodel/attention/base_att.py create mode 100644 lightllm/common/basemodel/attention/triton_backend.py rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/__init__.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/context_flashattention_nopad.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_att1.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_reduceV.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_softmax.py (77%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_flashattention_nopad.py (100%) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py new file mode 100644 index 0000000000..4e7938ec8b --- /dev/null +++ b/lightllm/common/basemodel/attention/__init__.py @@ -0,0 +1,2 @@ +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState +from .triton_backend import TritonAttBackend, TritonPrefillAttState, TritonDecodeAttState diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py new file mode 100644 index 0000000000..d3c0886866 --- /dev/null +++ b/lightllm/common/basemodel/attention/base_att.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +import torch +from typing import Optional + + +class BaseAttBackend: + """ + 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 这个是单列模式, 每种backend只有一个实例 + """ + + _instances = {} + + def __new__(cls, *args, **kwargs): + """ + 重写__new__方法实现单例模式 + """ + # 检查是否已经有该类的实例 + if cls not in cls._instances: + # 创建新实例并存储 + instance = super().__new__(cls) + cls._instances[cls] = instance + # 返回已有的实例 + return cls._instances[cls] + + def create_att_prefill_state(self) -> "BasePrefillAttState": + raise NotImplementedError("not impl") + + def create_att_decode_state(self) -> "BaseDecodeAttState": + raise NotImplementedError("not impl") + + +@dataclass +class BasePrefillAttState(ABC): + backend: BaseAttBackend = None + infer_state = None + + @abstractmethod + def init_state(self): + pass + + @abstractmethod + def copy_for_prefill_cuda_graph(self, new_state: "BasePrefillAttState"): + pass + + @abstractmethod + def prefill_att( + self, + q: torch.Tensor, + k: torch.tensor, + v: torch.tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + use_alibi=False, + ) -> torch.Tensor: + raise NotImplementedError("not impl") + + +@dataclass +class BaseDecodeAttState(ABC): + backend: BaseAttBackend = None + infer_state = None + + @abstractmethod + def init_state(self): + pass + + @abstractmethod + def copy_for_decode_cuda_graph(self, new_state: "BaseDecodeAttState"): + pass + + @abstractmethod + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + use_alibi=False, + ) -> torch.Tensor: + pass diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py new file mode 100644 index 0000000000..a1fd69170f --- /dev/null +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -0,0 +1,123 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState +from typing import Optional + + +class TritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": + return TritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": + return TritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class TritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def copy_for_prefill_cuda_graph(self, new_state: "TritonPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.tensor, + v: torch.tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + use_alibi=False, + ) -> torch.Tensor: + if use_alibi: + return self._alibi_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, out=out, alloc_func=alloc_func) + else: + raise NotImplementedError("error") + + def _alibi_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ): + from lightllm.common.basemodel.infer_struct import InferStateInfo + + infer_state: InferStateInfo = self.infer_state + out = alloc_func(q.shape, q.dtype) if out is None else out + + from ..triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd + + context_attention_fwd( + q, + k, + v, + out, + infer_state.b_req_idx, + layer_weight.tp_alibi, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + return out + + +@dataclasses.dataclass +class TritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"): + pass + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + use_alibi=False, + ): + if use_alibi: + return self._alibi_decode_att(q=q, k=k, v=v, layer_weight=layer_weight, out=out, alloc_func=alloc_func) + else: + raise NotImplementedError("error") + + def _alibi_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + out: Optional[torch.Tensor] = None, + alloc_func=torch.empty, + ): + from lightllm.common.basemodel.infer_struct import InferStateInfo + + infer_state: InferStateInfo = self.infer_state + + from ..triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd + + out = alloc_func(q.shape, q.dtype) if out is None else out + token_attention_fwd( + q, + k, + k, + out, + layer_weight.tp_alibi, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + infer_state.total_token_num, + alloc_tensor_func=alloc_func, + ) + return out diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 011f998fc0..9f39ff6567 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,6 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from .attention import TritonAttBackend logger = init_logger(__name__) @@ -119,6 +120,7 @@ def __init__(self, kvargs): self._init_inferstate_cls() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + self._init_att_backend() self._autotune_warmup() self._init_padded_req() self._init_cudagraph() @@ -238,6 +240,12 @@ def _init_some_value(self): self.vocab_size = self.config["vocab_size"] return + def _init_att_backend(self): + self.prefill_att_backend = TritonAttBackend() + self.decode_att_backend = TritonAttBackend() + assert id(self.prefill_att_backend) == id(self.decode_att_backend) + return + def _init_cudagraph(self): self.graph = ( None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch) @@ -311,6 +319,11 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) # 特殊模型,特殊模式的特定变量初始化操作。 infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens + if infer_state.is_prefill: + infer_state.prefill_att_state = self.prefill_att_backend.create_att_prefill_state(infer_state=infer_state) + else: + infer_state.decode_att_state = self.decode_att_backend.create_att_decode_state(infer_state=infer_state) + return infer_state def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 8e7174bb39..adad0a6abc 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -11,6 +11,7 @@ from .batch_objs import ModelInput from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_dp_rank +from .attention import BasePrefillAttState, BaseDecodeAttState class InferStateInfo: @@ -19,6 +20,10 @@ class InferStateInfo: """ def __init__(self): + # prefill 和 decode 使用的 att 状态对象 + self.prefill_att_state: BasePrefillAttState = None + self.decode_att_state: BaseDecodeAttState = None + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None @@ -90,6 +95,10 @@ def __init__(self): self.dp_input_split_sizes: List[List[int]] = None def init_some_extra_state(self, model): + if self.is_prefill: + self.prefill_att_state.init_state() + else: + self.decode_att_state.init_state() if self.is_prefill: ( diff --git a/lightllm/models/bloom/triton_kernel/__init__.py b/lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/__init__.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py diff --git a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py similarity index 77% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py index 25af80fabf..adf97735f6 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py @@ -6,11 +6,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -19,16 +23,22 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return @@ -44,10 +54,14 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) @@ -59,7 +73,7 @@ def test1(): import torch B, N_CTX, H, D = 4, 1025, 12, 128 - + del D dtype = torch.float16 Logics = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) @@ -85,6 +99,7 @@ def test2(): import torch B, N_CTX, H, D = 3, 1025, 12, 128 + del D dtype = torch.float16 @@ -107,7 +122,7 @@ def test2(): start = 0 for i in range(B): end = start + b_seq_len[i] - torch_o = Logics[:, start: end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) + torch_o = Logics[:, start:end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) start = end torch_out.append(torch_o) torch_out = torch.cat(torch_out, dim=-1) diff --git a/lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index d82a23d039..71b710ed58 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -2,8 +2,6 @@ from typing import Tuple from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight -from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd -from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd from lightllm.common.basemodel import InferStateInfo @@ -43,45 +41,39 @@ def _get_qkv( return q, cache_kv def _context_attention_kernel( - self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: BloomTransformerLayerWeight, + out=None, ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + layer_weight=layer_weight, + alloc_func=self.alloc_tensor, + use_alibi=True, ) + o_tensor = o_tensor.view(q.shape) return o_tensor def _token_attention_kernel( - self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + self, q: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - token_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.total_token_num, - alloc_tensor_func=self.alloc_tensor, + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor, use_alibi=True ) - return o_tensor + return o_tensor.view(q.shape) def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_)) diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 7e44ec2ebf..451c16fc1a 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -5,6 +5,7 @@ from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel +from lightllm.common.basemodel.attention import TritonAttBackend @ModelRegistry("bloom") @@ -35,3 +36,8 @@ def _init_config(self): def _reset_num_key_value_heads(self): self.config["num_key_value_heads"] = self.config["num_attention_heads"] return + + def _init_att_backend(self): + self.prefill_att_backend = TritonAttBackend() + self.decode_att_backend = TritonAttBackend() + return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d193bab41c..6ed2553eab 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -336,6 +336,28 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""inference backend will use microbatch overlap mode for decode. only deepseekv3 model supported now.""", ) + parser.add_argument( + "--llm_prefill_att_backend", + type=str, + choices=[None, "triton", "fa3", "flashinfer"], + default=None, + help="""prefill attention kernel used in llm""", + ) + parser.add_argument( + "--llm_decode_att_backend", + type=str, + choices=[None, "triton", "fa3", "flashinfer"], + default=None, + help="""decode attention kernel used in llm""", + ) + parser.add_argument( + "--llm_kv_type", + type=str, + choices=[None, ""], + default=None, + help="""kv type used in llm""", + ) + parser.add_argument( "--enable_flashinfer_prefill", action="store_true", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5ebadaf165..6065923ffb 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -118,6 +118,9 @@ class StartArgs: vit_quant_cfg: Optional[str] = field(default=None) enable_flashinfer_prefill: bool = field(default=False) enable_flashinfer_decode: bool = field(default=False) + llm_prefill_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) + llm_decode_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) + llm_kv_type: str = field(default=None, metadata={"choices": [None, ""]}) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} From 395a593e3dddf218fbfa248d31cee4587553ecbf Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 08:11:55 +0000 Subject: [PATCH 002/114] fix alibi att backend --- lightllm/common/basemodel/attention/base_att.py | 5 +++-- lightllm/common/basemodel/attention/triton_backend.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index d3c0886866..c1b3e1b6e9 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -33,8 +33,9 @@ def create_att_decode_state(self) -> "BaseDecodeAttState": @dataclass class BasePrefillAttState(ABC): + backend: BaseAttBackend = None - infer_state = None + infer_state: object = None @abstractmethod def init_state(self): @@ -61,7 +62,7 @@ def prefill_att( @dataclass class BaseDecodeAttState(ABC): backend: BaseAttBackend = None - infer_state = None + infer_state: object = None @abstractmethod def init_state(self): diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index a1fd69170f..b2dfe1d5d1 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -109,7 +109,7 @@ def _alibi_decode_att( token_attention_fwd( q, k, - k, + v, out, layer_weight.tp_alibi, infer_state.req_manager.req_to_token_indexs, From 1d794aa14afcc3c1508e67253fc057ce14d0158d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 10:34:36 +0000 Subject: [PATCH 003/114] fix llama triton att backend --- .../common/basemodel/attention/base_att.py | 13 +- .../basemodel/attention/triton_backend.py | 128 ++++++--- .../basemodel/triton_kernel/att/__init__.py | 0 .../att}/context_flashattention_nopad.py | 0 .../triton_kernel/att}/flash_decoding.py | 5 +- .../att}/flash_decoding_stage1.py | 106 ++++++-- .../att}/flash_decoding_stage2.py | 50 ++-- .../triton_kernel/att/gqa_flash_decoding.py | 38 +++ .../att/gqa_flash_decoding_stage1.py | 176 ++++++++++++ .../att/gqa_flash_decoding_stage2.py | 82 ++++++ .../layer_infer/transformer_layer_infer.py | 250 ++++++++---------- 11 files changed, 627 insertions(+), 221 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/__init__.py rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/context_flashattention_nopad.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/flash_decoding.py (87%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/flash_decoding_stage1.py (63%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/flash_decoding_stage2.py (65%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index c1b3e1b6e9..d3be114f9a 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -1,7 +1,10 @@ +import torch from abc import ABC, abstractmethod from dataclasses import dataclass -import torch -from typing import Optional +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo class BaseAttBackend: @@ -35,7 +38,7 @@ def create_att_decode_state(self) -> "BaseDecodeAttState": class BasePrefillAttState(ABC): backend: BaseAttBackend = None - infer_state: object = None + infer_state: "InferStateInfo" = None @abstractmethod def init_state(self): @@ -52,7 +55,6 @@ def prefill_att( k: torch.tensor, v: torch.tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, use_alibi=False, ) -> torch.Tensor: @@ -62,7 +64,7 @@ def prefill_att( @dataclass class BaseDecodeAttState(ABC): backend: BaseAttBackend = None - infer_state: object = None + infer_state: "InferStateInfo" = None @abstractmethod def init_state(self): @@ -79,7 +81,6 @@ def decode_att( k: torch.Tensor, v: torch.Tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, use_alibi=False, ) -> torch.Tensor: diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index b2dfe1d5d1..603261f71e 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -23,17 +23,16 @@ def copy_for_prefill_cuda_graph(self, new_state: "TritonPrefillAttState"): def prefill_att( self, q: torch.Tensor, - k: torch.tensor, - v: torch.tensor, + k: torch.Tensor, + v: torch.Tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, use_alibi=False, ) -> torch.Tensor: if use_alibi: - return self._alibi_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, out=out, alloc_func=alloc_func) + return self._alibi_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) else: - raise NotImplementedError("error") + return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -41,13 +40,9 @@ def _alibi_prefill_att( k: torch.Tensor, v: torch.Tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, ): - from lightllm.common.basemodel.infer_struct import InferStateInfo - - infer_state: InferStateInfo = self.infer_state - out = alloc_func(q.shape, q.dtype) if out is None else out + out = alloc_func(q.shape, q.dtype) from ..triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd @@ -56,13 +51,33 @@ def _alibi_prefill_att( k, v, out, - infer_state.b_req_idx, + self.infer_state.b_req_idx, layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_len_in_batch, + self.infer_state.req_manager.req_to_token_indexs, + ) + return out + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ): + from ..triton_kernel.att.context_flashattention_nopad import context_attention_fwd + + out = alloc_func(q.shape, q.dtype) + context_attention_fwd( + q, + k, + v, + out, + self.infer_state.b_req_idx, + self.infer_state.b_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_len_in_batch, + self.infer_state.req_manager.req_to_token_indexs, ) return out @@ -81,14 +96,24 @@ def decode_att( k: torch.Tensor, v: torch.Tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, use_alibi=False, ): if use_alibi: - return self._alibi_decode_att(q=q, k=k, v=v, layer_weight=layer_weight, out=out, alloc_func=alloc_func) + return self._alibi_decode_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) else: - raise NotImplementedError("error") + q_head_num = q.shape[1] + k_head_num = k.shape[1] + if q_head_num == k_head_num: + return self._normal_decode_flash_decoding_att( + q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func + ) + elif q_head_num > k_head_num: + return self._normal_decode_gqa_flash_decoding_att( + q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func + ) + else: + raise NotImplementedError("error") def _alibi_decode_att( self, @@ -96,28 +121,69 @@ def _alibi_decode_att( k: torch.Tensor, v: torch.Tensor, layer_weight, - out: Optional[torch.Tensor] = None, alloc_func=torch.empty, ): - from lightllm.common.basemodel.infer_struct import InferStateInfo - - infer_state: InferStateInfo = self.infer_state - from ..triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd - out = alloc_func(q.shape, q.dtype) if out is None else out + out = alloc_func(q.shape, q.dtype) token_attention_fwd( q, k, v, out, layer_weight.tp_alibi, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.total_token_num, + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_start_loc, + self.infer_state.b_seq_len, + self.infer_state.max_len_in_batch, + self.infer_state.total_token_num, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + from ..triton_kernel.att.flash_decoding import token_decode_attention_flash_decoding + + out = alloc_func(q.shape, q.dtype) + + token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, alloc_tensor_func=alloc_func, ) return out + + def _normal_decode_gqa_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + from ..triton_kernel.att.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding + + out = alloc_func(q.shape, q.dtype) + + print("wzj use gqa decode") + gqa_token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, + alloc_tensor_func=alloc_func, + ) + + return out diff --git a/lightllm/common/basemodel/triton_kernel/att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/context_flashattention_nopad.py similarity index 100% rename from lightllm/models/llama/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/context_flashattention_nopad.py diff --git a/lightllm/models/llama/triton_kernel/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/flash_decoding.py similarity index 87% rename from lightllm/models/llama/triton_kernel/flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/flash_decoding.py index e47e308864..a386212486 100644 --- a/lightllm/models/llama/triton_kernel/flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/flash_decoding.py @@ -1,12 +1,11 @@ import torch -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size max_len_in_batch = infer_state.max_len_in_batch + q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) from .flash_decoding_stage1 import flash_decode_stage1 diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage1.py similarity index 63% rename from lightllm/models/llama/triton_kernel/flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage1.py index 86a3af103d..4691e2db50 100644 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage1.py @@ -2,21 +2,40 @@ import triton import triton.language as tl + @triton.jit def _fwd_kernel_flash_decode_stage1( - Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, gqa_group_size, - BLOCK_SEQ: tl.constexpr, + BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -30,11 +49,18 @@ def _fwd_kernel_flash_decode_stage1( cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N - + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - + q = tl.load(Q + off_q) sum_exp = 0.0 @@ -43,7 +69,11 @@ def _fwd_kernel_flash_decode_stage1( for start_n in range(0, block_n_size, 1): offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) k_loc = k_loc.to(tl.int64) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) @@ -51,7 +81,7 @@ def _fwd_kernel_flash_decode_stage1( att_value *= sm_scale att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - + cur_max_logic = tl.max(att_value, axis=0) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -62,7 +92,7 @@ def _fwd_kernel_flash_decode_stage1( sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) max_logic = new_max_logic - + need_store = tl.where(block_n_size == 0, 0, 1) for _ in range(0, need_store, 1): off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d @@ -73,7 +103,9 @@ def _fwd_kernel_flash_decode_stage1( @torch.no_grad() -def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq): +def flash_decode_stage1( + q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq +): BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 @@ -85,17 +117,35 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_ batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] - + _fwd_kernel_flash_decode_stage1[grid]( - q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, + q, + k, + v, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, mid_out, mid_out_logsumexp, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), gqa_group_size, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, @@ -103,4 +153,4 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_ num_warps=1, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage2.py similarity index 65% rename from lightllm/models/llama/triton_kernel/flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage2.py index 81227f967b..101e99dde5 100644 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage2.py @@ -6,14 +6,22 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): + BLOCK_DMODEL: tl.constexpr, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -32,33 +40,43 @@ def _fwd_kernel_flash_decode_stage2( tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) - + old_scale = tl.exp(max_logic - new_max_logic) acc *= old_scale exp_logic = tl.exp(tlogic - new_max_logic) acc += exp_logic * tv sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) - + _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), + B_Seqlen, + mid_out, + mid_out_logexpsum, + out, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py new file mode 100644 index 0000000000..c56bf7d5ab --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py @@ -0,0 +1,38 @@ +import torch + + +def gqa_token_decode_attention_flash_decoding( + q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty +): + BLOCK_SEQ = 128 + batch_size = infer_state.batch_size + max_len_in_batch = infer_state.max_len_in_batch + q_head_num, head_dim = q.shape[1], q.shape[2] + calcu_shape1 = (batch_size, q_head_num, head_dim) + + from .gqa_flash_decoding_stage1 import flash_decode_stage1 + from .gqa_flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out + + mid_o = alloc_tensor_func( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + ) + mid_o_logexpsum = alloc_tensor_func( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + ) + + flash_decode_stage1( + q.view(calcu_shape1), + cache_k, + cache_v, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + mid_o, + mid_o_logexpsum, + BLOCK_SEQ, + ) + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py new file mode 100644 index 0000000000..320c2cf798 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py @@ -0,0 +1,176 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + Q_HEAD_NUM: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) + cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0) + + sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) + max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") + acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ).to(tl.int64) + off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf")) + v = tl.load( + V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], + mask=offs_n_new[:, None] < cur_batch_end_index, + other=0.0, + ) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(v.dtype), v) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + seq_start_block * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size, + ) + return + + +@torch.no_grad() +def flash_decode_stage1( + q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + + _fwd_kernel_flash_decode_stage1[grid]( + q, + k, + v, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)), + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK_N, + num_warps=2, + num_stages=2, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py new file mode 100644 index 0000000000..101e99dde5 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py @@ -0,0 +1,82 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) + return + + +@torch.no_grad() +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): + Lk = mid_out.shape[-1] + assert Lk in {16, 32, 64, 128} + batch, head_num = mid_out.shape[0], mid_out.shape[1] + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + mid_out, + mid_out_logexpsum, + out, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=Lk, + num_warps=4, + num_stages=2, + ) + return diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index b08b2aa1fd..71dd2a69e2 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -7,10 +7,6 @@ from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, - context_attention_fwd_ppl_int8kv, -) from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v @@ -69,30 +65,30 @@ def _bind_norm(self): def _bind_attention(self): if get_env_start_args().enable_fa3: if "offline_calibration_fp8kv" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self + # ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) elif "export_fp8kv_calibration" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_flashattention, self + # ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_flashattention, self + # ) self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) elif not self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_flashattention, self + # ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_flashattention, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) else: raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") @@ -102,84 +98,86 @@ def _bind_attention(self): LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self ) else: - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + # self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + pass if "ppl_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self) + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, + # self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self + # ) elif "ppl_int8kv_flashdecoding_diverse" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self - ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self + # ) elif "ppl_int8kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self + # ) + # self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, + # self) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self + # ) + pass elif "ppl_int4kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self - ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self) elif "ppl_fp16" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "ppl_fp16_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self - ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) elif "offline_calibration_fp8kv" in self.mode: assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self - ) + # self._context_attention_kernel = partial( + # LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self + # ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self + # ) elif "triton_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self - ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_attention" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, self) + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, + # self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self - ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + elif "triton_gqa_flashdecoding_vsm" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self - ) + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self + # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "export_fp8kv_calibration" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self) + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, + # self) self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) elif not self.mode: - if get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) + # if get_env_start_args().enable_flashinfer_decode: + # self._token_attention_kernel = partial( + # LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self + # ) + # else: + # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, + # self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) else: raise Exception(f"Unsupported mode: {self.mode}") @@ -241,6 +239,40 @@ def _tpsp_get_qkv( return q, cache_kv + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + out=None, + ) -> torch.Tensor: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + layer_weight=layer_weight, + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor + + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, out=None + ) -> torch.Tensor: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) + def _context_attention_flashinfer_kernel_fp8( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -274,25 +306,6 @@ def _context_attention_flashinfer_kernel( ) return o_tensor - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - def _context_attention_kernel_ppl_int8kv( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -313,16 +326,16 @@ def _context_attention_kernel_ppl_int8kv( max_seq_len, kv_dequant, ) - context_attention_fwd_ppl_int8kv( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.b_ready_cache_len, - ) + # context_attention_fwd_ppl_int8kv( + # q.view(-1, self.tp_q_head_num_, self.head_dim_), + # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], + # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], + # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + # infer_state.b_start_loc, + # infer_state.b_seq_len, + # infer_state.max_len_in_batch, + # infer_state.b_ready_cache_len, + # ) return o_tensor def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): @@ -658,43 +671,6 @@ def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, la prob = None return o_tensor - def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_gqa_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - # 对 gqa 模型进行推理优化的代码 - from ..triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return gqa_token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) From 5120d01d279e901a32ef09274ea3481feb098586 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 10:48:30 +0000 Subject: [PATCH 004/114] fix att --- .../basemodel/attention/triton_backend.py | 48 ++++- .../att}/gqa_decode_flashattention_nopad.py | 0 .../att}/gqa_flash_decoding_vsm.py | 0 .../layer_infer/transformer_layer_infer.py | 24 --- .../llama/triton_kernel/gqa_flash_decoding.py | 37 ---- .../gqa_flash_decoding_stage1.py | 176 ------------------ .../gqa_flash_decoding_stage2.py | 64 ------- 7 files changed, 47 insertions(+), 302 deletions(-) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/gqa_decode_flashattention_nopad.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att}/gqa_flash_decoding_vsm.py (100%) delete mode 100644 lightllm/models/llama/triton_kernel/gqa_flash_decoding.py delete mode 100644 lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py delete mode 100644 lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index 603261f71e..1c6c8ccd7e 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -176,7 +176,6 @@ def _normal_decode_gqa_flash_decoding_att( out = alloc_func(q.shape, q.dtype) - print("wzj use gqa decode") gqa_token_decode_attention_flash_decoding( q=q, infer_state=self.infer_state, @@ -187,3 +186,50 @@ def _normal_decode_gqa_flash_decoding_att( ) return out + + def _normal_decode_gqa_flash_decoding_att_vsm( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ..triton_kernel.att.gqa_flash_decoding_vsm import gqa_token_decode_attention_flash_decoding_vsm + + out = alloc_func(q.shape, q.dtype) + + gqa_token_decode_attention_flash_decoding_vsm( + q=q, + k=k, + v=v, + infer_state=self.infer_state, + out=out, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_gqa_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ..triton_kernel.att.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + + out = alloc_func(q.shape, q.dtype) + + gqa_decode_attention_fwd( + q=q, + k=k, + v=v, + out=out, + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_seq_len=self.infer_state.b_seq_len, + ) + return out diff --git a/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/gqa_decode_flashattention_nopad.py similarity index 100% rename from lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/gqa_decode_flashattention_nopad.py diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_vsm.py similarity index 100% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py rename to lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_vsm.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 71dd2a69e2..ee0e6595df 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -152,10 +152,6 @@ def _bind_attention(self): # ) elif "triton_flashdecoding" in self.mode: self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_attention" in self.mode: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, - # self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_flashdecoding" in self.mode: self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) @@ -609,26 +605,6 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, la ) return o_tensor - def _token_decode_gqa_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - # 对 gqa模型进行推理优化的代码 - from ..triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - gqa_decode_attention_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - ) - return o_tensor - def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py deleted file mode 100644 index 67be7c968b..0000000000 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - - -def gqa_token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): - BLOCK_SEQ = 128 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from .gqa_flash_decoding_stage1 import flash_decode_stage1 - from .gqa_flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) - - flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, - ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py deleted file mode 100644 index 320c2cf798..0000000000 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage1( - Q, - K, - V, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eb, - stride_mid_o_eh, - stride_mid_o_es, - gqa_group_size, - Q_HEAD_NUM: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_kv_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - - cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) - cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] - - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - - q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0) - - sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") - acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ).to(tl.int64) - off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] - k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value = tl.dot(q, k.to(q.dtype)) - att_value *= sm_scale - att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf")) - v = tl.load( - V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], - mask=offs_n_new[:, None] < cur_batch_end_index, - other=0.0, - ) - - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(v.dtype), v) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = ( - cur_batch * stride_mid_ob - + cur_q_head_range[:, None] * stride_mid_oh - + seq_start_block * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size, - ) - return - - -@torch.no_grad() -def flash_decode_stage1( - q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq -): - BLOCK_SEQ = block_seq - BLOCK_N = 16 - assert BLOCK_SEQ % BLOCK_N == 0 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk ** 0.5) - batch, kv_head_num = B_req_idx.shape[0], k.shape[1] - grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) - gqa_group_size = q.shape[1] // k.shape[1] - - _fwd_kernel_flash_decode_stage1[grid]( - q, - k, - v, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logsumexp.stride(0), - mid_out_logsumexp.stride(1), - mid_out_logsumexp.stride(2), - gqa_group_size, - Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)), - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK_N, - num_warps=2, - num_stages=2, - ) - return diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py deleted file mode 100644 index 81227f967b..0000000000 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage2( - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) - new_max_logic = tl.maximum(tlogic, max_logic) - - old_scale = tl.exp(max_logic - new_max_logic) - acc *= old_scale - exp_logic = tl.exp(tlogic - new_max_logic) - acc += exp_logic * tv - sum_exp = sum_exp * old_scale + exp_logic - max_logic = new_max_logic - - tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) - return - - -@torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): - Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} - batch, head_num = mid_out.shape[0], mid_out.shape[1] - grid = (batch, head_num) - - _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), - BLOCK_SEQ=block_seq, - BLOCK_DMODEL=Lk, - num_warps=4, - num_stages=2, - ) - return \ No newline at end of file From 7f05167476344017d4b139ae6d368edb32463239 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 11:04:27 +0000 Subject: [PATCH 005/114] fix att --- .../basemodel/attention/triton_backend.py | 16 +- .../triton_kernel/att/decode_att/__init__.py | 0 .../att/decode_att/gqa/__init__.py | 0 .../decode_att/gqa/flash_decoding/__init__.py | 0 .../gqa/flash_decoding}/gqa_flash_decoding.py | 0 .../gqa_flash_decoding_stage1.py | 0 .../gqa_flash_decoding_stage2.py | 0 .../flash_decoding}/gqa_flash_decoding_vsm.py | 0 .../gqa}/gqa_decode_flashattention_nopad.py | 0 .../att/decode_att/mha/__init__.py | 0 .../decode_att/mha/flash_decoding/__init__.py | 0 .../mha/flash_decoding}/flash_decoding.py | 0 .../flash_decoding}/flash_decoding_stage1.py | 0 .../flash_decoding}/flash_decoding_stage2.py | 0 .../mha/stage3_decode_att/__init__.py | 0 .../token_attention_nopad_att1.py | 233 ++++++++++++++++++ .../token_attention_nopad_reduceV.py | 223 +++++++++++++++++ .../token_attention_nopad_softmax.py | 95 +++++++ .../token_attention_softmax_and_reducev.py | 112 +++++++++ .../triton_kernel/att/prefill_att/__init__.py | 0 .../context_flashattention_nopad.py | 0 .../layer_infer/transformer_layer_infer.py | 27 -- 22 files changed, 674 insertions(+), 32 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/gqa/flash_decoding}/gqa_flash_decoding.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/gqa/flash_decoding}/gqa_flash_decoding_stage1.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/gqa/flash_decoding}/gqa_flash_decoding_stage2.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/gqa/flash_decoding}/gqa_flash_decoding_vsm.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/gqa}/gqa_decode_flashattention_nopad.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/mha/flash_decoding}/flash_decoding.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/mha/flash_decoding}/flash_decoding_stage1.py (100%) rename lightllm/common/basemodel/triton_kernel/att/{ => decode_att/mha/flash_decoding}/flash_decoding_stage2.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py rename lightllm/common/basemodel/triton_kernel/att/{ => prefill_att}/context_flashattention_nopad.py (100%) diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index 1c6c8ccd7e..7a7a88cb50 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -64,7 +64,7 @@ def _alibi_prefill_att( def _nomarl_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty ): - from ..triton_kernel.att.context_flashattention_nopad import context_attention_fwd + from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd out = alloc_func(q.shape, q.dtype) context_attention_fwd( @@ -150,7 +150,9 @@ def _normal_decode_flash_decoding_att( layer_weight, alloc_func=torch.empty, ): - from ..triton_kernel.att.flash_decoding import token_decode_attention_flash_decoding + from ..triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( + token_decode_attention_flash_decoding, + ) out = alloc_func(q.shape, q.dtype) @@ -172,7 +174,9 @@ def _normal_decode_gqa_flash_decoding_att( layer_weight, alloc_func=torch.empty, ): - from ..triton_kernel.att.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding + from ..triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, + ) out = alloc_func(q.shape, q.dtype) @@ -196,7 +200,9 @@ def _normal_decode_gqa_flash_decoding_att_vsm( alloc_func=torch.empty, ): # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 - from ..triton_kernel.att.gqa_flash_decoding_vsm import gqa_token_decode_attention_flash_decoding_vsm + from ..triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_vsm import ( + gqa_token_decode_attention_flash_decoding_vsm, + ) out = alloc_func(q.shape, q.dtype) @@ -219,7 +225,7 @@ def _normal_decode_gqa_att( alloc_func=torch.empty, ): # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 - from ..triton_kernel.att.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + from ..triton_kernel.att.decode_att.gqa.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd out = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/gqa_flash_decoding_vsm.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py diff --git a/lightllm/common/basemodel/triton_kernel/att/gqa_decode_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/gqa_decode_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py diff --git a/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py diff --git a/lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py new file mode 100644 index 0000000000..eb5af6fecd --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py @@ -0,0 +1,233 @@ +import torch + +import triton +import triton.language as tl +import math + + +@triton.jit +def _fwd_kernel_token_att1( + Q, + K, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_start_index = 0 + cur_batch_end_index = cur_batch_seq_len + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ).to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32) + att_value *= sm_scale + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs + tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + return + + +@torch.no_grad() +def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128, 256} + sm_scale = 1.0 / (Lk ** 0.5) + + batch, head_num = B_req_idx.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) + kv_group_num = q.shape[1] // k.shape[1] + + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 + + _fwd_kernel_token_att1[grid]( + q, + k, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), + kv_group_num=kv_group_num, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_kernel_token_att1_int8( + Q, + K, + K_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_ksbs, + stride_ksh, + stride_ksd, + att_stride_h, + att_stride_bs, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_start_index = 0 + cur_batch_end_index = cur_batch_seq_len + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ).to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + off_ks = k_loc[:, None] * stride_ksbs + cur_kv_head * stride_ksh + k_scale = tl.load(K_scale + off_ks, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k * k_scale, 1) + att_value *= sm_scale + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs + tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + return + + +@torch.no_grad() +def token_att_fwd_int8k(q, k, k_scale, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_input_len): + BLOCK = 32 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + + batch, head_num = B_req_idx.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) + + kv_group_num = q.shape[1] // k.shape[1] + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 + + _fwd_kernel_token_att1_int8[grid]( + q, + k, + k_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + att_out.stride(0), + att_out.stride(1), + kv_group_num=kv_group_num, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py new file mode 100644 index 0000000000..243a8d1f66 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py @@ -0,0 +1,223 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_token_att2( + Prob, + V, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = 0 + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) + v_loc = tl.load( + Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0, + ).to(tl.int64) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(Out.dtype.element_ty) + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + +@torch.no_grad() +def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): + BLOCK = 128 + # BLOCK = 64 # for triton 2.0.0dev + batch, head = B_req_idx.shape[0], prob.shape[0] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + kv_group_num = prob.shape[0] // v.shape[1] + + _fwd_kernel_token_att2[grid]( + prob, + v, + out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_kernel_token_att2_int8v( + Prob, + V, + V_scale, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_ph, + stride_pbs, + stride_vbs, + stride_vh, + stride_vd, + stride_vsbs, + stride_vsh, + stride_vsd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = 0 + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + vs_offs = cur_kv_head * stride_vsh + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) + v_loc = tl.load( + Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0, + ).to(tl.int64) + v_value = tl.load( + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + vs_value = tl.load( + V_scale + vs_offs + v_loc[:, None] * stride_vsbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) + + acc = acc.to(Out.dtype.element_ty) + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + +@torch.no_grad() +def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): + if max_len_in_batch < 512: + BLOCK = triton.next_power_of_2(max_len_in_batch) + else: + BLOCK = 512 + batch, head = B_req_idx.shape[0], prob.shape[0] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + kv_group_num = prob.shape[0] // v.shape[1] + + _fwd_kernel_token_att2_int8v[grid]( + prob, + v, + v_scale, + out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + v_scale.stride(0), + v_scale.stride(1), + v_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +def torch_att(V, P, bs, seqlen, num_head, head_dim): + V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) + P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) + out = torch.matmul(P, V) + + return out diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py new file mode 100644 index 0000000000..0bb6410e13 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py @@ -0,0 +1,95 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_token_softmax( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) + return + + +@torch.no_grad() +def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): + BLOCK_SIZE = triton.next_power_of_2(max_input_len) + batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _fwd_kernel_token_softmax[(batch, head_num)]( + Logics, + B_Start_Loc, + B_Seqlen, + Prob_Out, + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + +def test1(): + + import torch + + B, N_CTX, H, D = 4, 1025, 12, 128 + del D + + dtype = torch.float16 + + Logics = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) + ProbOut = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_seq_len = torch.zeros((B,), dtype=torch.int32, device="cuda") + + for i in range(B): + b_start_loc[i] = i * N_CTX + b_seq_len[i] = N_CTX + + token_softmax_fwd(Logics, b_start_loc, b_seq_len, ProbOut, N_CTX) + + torch_out = Logics.reshape(H * B, -1).softmax(-1).reshape(H, B * N_CTX) + o = ProbOut + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py new file mode 100644 index 0000000000..d963f8582a --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py @@ -0,0 +1,112 @@ +import torch +import triton +import triton.language as tl +import torch.nn.functional as F + + +@triton.jit +def _fwd_kernel( + Logics, + V, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_logic_h, + stride_logic_bs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_token_b, + stride_req_to_token_s, + other_kv_index, # 避免读取到nan的数据 + kv_group_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v_ptrs = V + off_v + + e_max = float("-inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_index = tl.load( + Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + (start_n + offs_n) * stride_req_to_token_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=other_kv_index, + ).to(tl.int64) + + qk = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, + mask=start_n + offs_n < cur_batch_seq_len, + other=float("-inf"), + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + e_sum = e_sum * old_scale + tl.sum(p, 0) + v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) + acc = acc * old_scale + tl.sum(p[:, None] * v, 0) + e_max = n_e_max + + acc = acc / e_sum + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + +@torch.no_grad() +def token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len): + BLOCK = 64 + batch, head = b_seq_len.shape[0], logics.shape[0] + grid = (batch, head) + kv_group_num = logics.shape[0] // v.shape[1] + + num_warps = 1 + _fwd_kernel[grid]( + logics, + v, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + logics.stride(0), + logics.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_tokens.stride(0), + req_to_tokens.stride(1), + 0, + kv_group_num, + BLOCK_DMODEL=v.shape[-1], + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=3, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index ee0e6595df..56eaa7e089 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -154,12 +154,6 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "triton_gqa_flashdecoding" in self.mode: self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - - elif "triton_gqa_flashdecoding_vsm" in self.mode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif "export_fp8kv_calibration" in self.mode: # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, # self) @@ -797,27 +791,6 @@ def _token_decode_attention_ppl_int4kv_flashdecoding( alloc_tensor_func=self.alloc_tensor, ) - def _token_decode_attention_gqa_flashdecoding_vsm( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import ( - gqa_token_decode_attention_flash_decoding_vsm, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_) - return gqa_token_decode_attention_flash_decoding_vsm( - q.view(q_shape), - cache_k, - cache_v, - infer_state, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ From 701133074f159059d54ffb594f9b126f61fe1aff Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 11:13:56 +0000 Subject: [PATCH 006/114] fix att --- .../basemodel/attention/triton_backend.py | 45 +++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 39 ---------------- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index 7a7a88cb50..5362b40d80 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -239,3 +239,48 @@ def _normal_decode_gqa_att( b_seq_len=self.infer_state.b_seq_len, ) return out + + def _normal_decode_stage3_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + total_token_num = self.infer_state.total_token_num + batch_size = self.infer_state.batch_size + q_head_num = q.shape[1] + head_dim = q.shape[2] + + calcu_shape1 = (batch_size, q_head_num, head_dim) + att_m_tensor = alloc_func((q_head_num, total_token_num), torch.float32) + + from ..triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_nopad_att1 import token_att_fwd + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + Req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + B_req_idx=self.infer_state.b_req_idx, + B_Start_Loc=self.infer_state.b_start_loc, + B_Seqlen=self.infer_state.b_seq_len, + max_len_in_batch=self.infer_state.max_len_in_batch, + ) + + o_tensor = alloc_func(q.shape, q.dtype) + from ..triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o_tensor.view(calcu_shape1), + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_start_loc=self.infer_state.b_start_loc, + b_seq_len=self.infer_state.b_seq_len, + ) + return o_tensor diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 56eaa7e089..ddcc365d63 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -165,9 +165,6 @@ def _bind_attention(self): # self._token_attention_kernel = partial( # LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self # ) - # else: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, - # self) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) else: raise Exception(f"Unsupported mode: {self.mode}") @@ -563,42 +560,6 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor - def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) - - token_att_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o_tensor - def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size From 08f43369d917f15aadf3040637ae8097d019a0f2 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 22:39:06 +0800 Subject: [PATCH 007/114] add att_control params --- .../common/basemodel/attention/__init__.py | 2 +- .../common/basemodel/attention/base_att.py | 22 +++++++++++++++++-- .../basemodel/attention/triton_backend.py | 10 ++++----- .../layer_infer/transformer_layer_infer.py | 10 +++++++-- 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 4e7938ec8b..379f3079ac 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -1,2 +1,2 @@ -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from .triton_backend import TritonAttBackend, TritonPrefillAttState, TritonDecodeAttState diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index d3be114f9a..76d5251109 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -34,6 +34,15 @@ def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") +@dataclass +class AttControl: + """ + prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 + """ + + use_alibi: bool = (False,) + + @dataclass class BasePrefillAttState(ABC): @@ -55,8 +64,8 @@ def prefill_att( k: torch.tensor, v: torch.tensor, layer_weight, + att_control: AttControl = AttControl(), alloc_func=torch.empty, - use_alibi=False, ) -> torch.Tensor: raise NotImplementedError("not impl") @@ -81,7 +90,16 @@ def decode_att( k: torch.Tensor, v: torch.Tensor, layer_weight, + att_control: AttControl = AttControl(), alloc_func=torch.empty, - use_alibi=False, ) -> torch.Tensor: pass + + +@dataclass +class AttControl: + """ + prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 + """ + + use_alibi: bool = False diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index 5362b40d80..7b6d3fd918 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional @@ -26,10 +26,10 @@ def prefill_att( k: torch.Tensor, v: torch.Tensor, layer_weight, + att_control: AttControl = AttControl(), alloc_func=torch.empty, - use_alibi=False, ) -> torch.Tensor: - if use_alibi: + if att_control.use_alibi: return self._alibi_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) else: return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) @@ -96,10 +96,10 @@ def decode_att( k: torch.Tensor, v: torch.Tensor, layer_weight, + att_control: AttControl = AttControl(), alloc_func=torch.empty, - use_alibi=False, ): - if use_alibi: + if att_control.use_alibi: return self._alibi_decode_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) else: q_head_num = q.shape[1] diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 71b710ed58..d156c09279 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -3,6 +3,7 @@ from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.common.basemodel import InferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl class BloomTransformerLayerInfer(TransformerLayerInferTpl): @@ -57,8 +58,8 @@ def _context_attention_kernel( k=_k, v=_v, layer_weight=layer_weight, + att_control=AttControl(use_alibi=True), alloc_func=self.alloc_tensor, - use_alibi=True, ) o_tensor = o_tensor.view(q.shape) return o_tensor @@ -71,7 +72,12 @@ def _token_attention_kernel( _k = kv[:, 0 : self.tp_k_head_num_, :] _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] o_tensor = infer_state.decode_att_state.decode_att( - q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor, use_alibi=True + q=_q, + k=_k, + v=_v, + layer_weight=layer_weight, + att_control=AttControl(use_alibi=True), + alloc_func=self.alloc_tensor, ) return o_tensor.view(q.shape) From d6a44053d1572c4ceb50a6dac4013556c02a6909 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 5 Jan 2026 23:23:46 +0800 Subject: [PATCH 008/114] fix int8kv --- .../attention/int8kv_triton_backend.py | 89 +++++++++++++ .../token_attention_nopad_att1.py | 119 ------------------ .../token_attention_nopad_reduceV.py | 113 ----------------- .../layer_infer/transformer_layer_infer.py | 45 ------- 4 files changed, 89 insertions(+), 277 deletions(-) create mode 100644 lightllm/common/basemodel/attention/int8kv_triton_backend.py diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py new file mode 100644 index 0000000000..e9783e5a27 --- /dev/null +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -0,0 +1,89 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple + + +class Int8kvTritonAttBackend(BaseAttBackend): + def __init__(self, quant_group_size: int): + self.quant_group_size: int = quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int8kvTritonPrefillAttState": + return Int8kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int8kvTritonDecodeAttState": + return Int8kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int8kvTritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def copy_for_prefill_cuda_graph(self, new_state: "Int8kvTritonPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + + return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ): + raise NotImplementedError("not impl") + + +@dataclasses.dataclass +class Int8kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int8kvTritonDecodeAttState"): + pass + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + q = q + k, k_scale = k + v, v_scale = v + if k_scale.ndim == 3 and k_scale.shape[2] == 1: + return self._per_head_quant_decode_stage3_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + else: + raise NotImplementedError("not support decode att") + + def _per_head_quant_decode_stage3_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + raise NotImplementedError("error") diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py index eb5af6fecd..9de2b82057 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py @@ -112,122 +112,3 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen num_stages=1, ) return - - -@triton.jit -def _fwd_kernel_token_att1_int8( - Q, - K, - K_scale, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Att_Out, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_ksbs, - stride_ksh, - stride_ksd, - att_stride_h, - att_stride_bs, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - cur_batch_start_index = 0 - cur_batch_end_index = cur_batch_seq_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ).to(tl.int64) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - off_ks = k_loc[:, None] * stride_ksbs + cur_kv_head * stride_ksh - k_scale = tl.load(K_scale + off_ks, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k * k_scale, 1) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - -@torch.no_grad() -def token_att_fwd_int8k(q, k, k_scale, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_input_len): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk ** 0.5) - - batch, head_num = B_req_idx.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - - kv_group_num = q.shape[1] // k.shape[1] - if kv_group_num == 1: - num_warps = 4 - else: - num_warps = 2 - - _fwd_kernel_token_att1_int8[grid]( - q, - k, - k_scale, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - att_out, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - k_scale.stride(0), - k_scale.stride(1), - k_scale.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py index 243a8d1f66..96a5b26dd6 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py @@ -102,119 +102,6 @@ def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen return -@triton.jit -def _fwd_kernel_token_att2_int8v( - Prob, - V, - V_scale, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_vsbs, - stride_vsh, - stride_vsd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - vs_offs = cur_kv_head * stride_vsh - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - vs_value = tl.load( - V_scale + vs_offs + v_loc[:, None] * stride_vsbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): - if max_len_in_batch < 512: - BLOCK = triton.next_power_of_2(max_len_in_batch) - else: - BLOCK = 512 - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2_int8v[grid]( - prob, - v, - v_scale, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - v_scale.stride(0), - v_scale.stride(1), - v_scale.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - def torch_att(V, P, bs, seqlen, num_head, head_dim): V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index ddcc365d63..36fecc028c 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -138,9 +138,6 @@ def _bind_attention(self): # LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self # ) self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_int8kv" in self.mode: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) elif "offline_calibration_fp8kv" in self.mode: assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) @@ -560,48 +557,6 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor - def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), q.dtype) - token_att_fwd_int8k( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - prob = self.alloc_tensor(att_m_tensor.shape, att_m_tensor.dtype) - token_softmax_fwd( - att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch - ) - att_m_tensor = None - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - token_att_fwd2_int8v( - prob, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - prob = None - return o_tensor - def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) From 86e1f59f1ed5f6dc894d22d0ce4761989494b5cd Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 02:34:13 +0000 Subject: [PATCH 009/114] fix --- lightllm/common/basemodel/attention/base_att.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 76d5251109..cccc81992a 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -40,7 +40,7 @@ class AttControl: prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 """ - use_alibi: bool = (False,) + use_alibi: bool = False @dataclass From 7e45e69a30e5695beb8b67075474ba12589cc3d4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 05:50:14 +0000 Subject: [PATCH 010/114] add new int8kv dequant triton kernel --- .../triton_kernel/kv_copy/__init__.py | 0 .../kv_copy/ppl_quant_copy_kv.py | 330 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 3 - .../kv_copy/test_ppl_quant_copy_kv.py | 86 +++++ 4 files changed, 416 insertions(+), 3 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py create mode 100644 unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py b/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py new file mode 100644 index 0000000000..f1d85bed84 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py @@ -0,0 +1,330 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_size, + BLOCK_GROUP_NUM: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_g = tl.arange(0, BLOCK_GROUP_NUM) + offs_d = tl.arange(0, BLOCK_GROUP_DIM) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + src_data = tl.load( + K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], + mask=offs_g[:, None] < group_size, + other=0.0, + ) + abs_data = tl.abs(src_data) + data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) + q_src_data = (src_data / data_scale[:, None]).to(tl.int8) + + o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) + tl.store(os_ptrs, data_scale, mask=offs_g < group_size) + return + + +@torch.no_grad() +def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + quant_group_dim = 8 + + assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" + grid = (seq_len, head_num) + num_warps = 1 + + group_size = head_dim // quant_group_dim + group_dim = quant_group_dim + + K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) + Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) + + _fwd_kernel_destindex_copy_quantize_kv[grid]( + K, + DestLoc, + Out, + Out_scale, + K.stride(0), + K.stride(1), + K.stride(2), + K.stride(3), + Out.stride(0), + Out.stride(1), + Out.stride(2), + Out.stride(3), + Out_scale.stride(0), + Out_scale.stride(1), + Out_scale.stride(2), + group_size, + BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), + BLOCK_GROUP_DIM=group_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_dequantize_int8kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) % group_dim + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): + k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] + ) + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + + group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int8.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] + ) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int8.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int8kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + v_head_num = v.shape[1] + v_head_dim = v.shape[2] + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim)) + v = v.view((v.shape[0], v.shape[1], group_count, group_dim)) + k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) + v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1)) + + # 使拆分的grid 具有足够的并行度 + SEQ_BLOCK_SIZE = 128 + while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512: + SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2 + if SEQ_BLOCK_SIZE <= 1: + break + + if SEQ_BLOCK_SIZE <= 1: + SEQ_BLOCK_SIZE = 8 + + grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size) + num_warps = 4 + k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim)) + v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim)) + + _fwd_dequantize_int8kv[grid]( + k=k, + k_ss=k.stride(0), + k_sh=k.stride(1), + k_sg=k.stride(2), + k_sd=k.stride(3), + k_scale=k_scale, + k_scale_ss=k_scale.stride(0), + k_scale_sh=k_scale.stride(1), + k_scale_sg=k_scale.stride(2), + k_scale_sd=k_scale.stride(2), + v=v, + v_ss=v.stride(0), + v_sh=v.stride(1), + v_sg=v.stride(2), + v_sd=v.stride(3), + v_scale=v_scale, + v_scale_ss=v_scale.stride(0), + v_scale_sh=v_scale.stride(1), + v_scale_sg=v_scale.stride(2), + v_scale_sd=v_scale.stride(3), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + stride_req_to_tokens_s=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=k_out, + k_out_ss=k_out.stride(0), + k_out_sh=k_out.stride(1), + k_out_sg=k_out.stride(2), + k_out_sd=k_out.stride(3), + v_out=v_out, + v_out_ss=v_out.stride(0), + v_out_sh=v_out.stride(1), + v_out_sg=v_out.stride(2), + v_out_sd=v_out.stride(3), + k_head_num=k_head_num, + v_head_num=v_head_num, + group_count=group_count, + group_dim=group_dim, + SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, + GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + num_warps=num_warps, + num_stages=1, + ) + return + + +def test2(): + import time + + B, N_CTX, H, D = 1, 3, 12, 128 + src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() + dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() + value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) + scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda() + + for _ in range(10): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t1 = time.time() + for _ in range(1000): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t2 = time.time() + + print("Time cost ", t2 - t1) + value_dest = value_dest.view((B * N_CTX, H, D // 8, 8)) + scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1)) + print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) + print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) + cos = torch.nn.CosineSimilarity(0) + print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 36fecc028c..bcfc710861 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -7,9 +7,6 @@ from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k -from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd -from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py new file mode 100644 index 0000000000..a3449de808 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py @@ -0,0 +1,86 @@ +import torch +import time +import pytest +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_quant_copy_kv import ( + dequantize_int8kv, + destindex_copy_quantize_kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_dequant(kv, kv_scale, b_req_idx, b_seq_len, req_to_token_indexs, odtype, group_quant_size): + batch = b_req_idx.shape[0] + tmp_out = [] + for i in range(batch): + req_idx = b_req_idx[i] + seq_len = b_seq_len[i] + kv_loc = req_to_token_indexs[req_idx, :seq_len] + head_num = kv.shape[1] + cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, group_quant_size).to(odtype) + cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) + out = cur_kv * cur_scale + tmp_out.append(out.reshape(seq_len, head_num, -1)) + return torch.cat(tmp_out, dim=0) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, group_quant_size", + [ + (b, H, N_CTX, D_HEAD, group_quant_size) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for group_quant_size in [8, 16] + ], +) +def test_dequantize_int8kv(B, H, N_CTX, D_HEAD, group_quant_size): + dtype = torch.bfloat16 + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=torch.int8, device="cuda").random_(-10, 10) + kv_scale = torch.randn((B * N_CTX, 2 * H, D_HEAD // group_quant_size), dtype=dtype, device="cuda") + out = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda") + req_to_token_indexs = torch.empty((B, N_CTX), dtype=torch.int32, device="cuda") + max_input_len = N_CTX + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_seq_len.fill_(N_CTX) + b_req_idx = torch.arange(0, B, dtype=torch.int32, device="cuda") + req_to_token_indexs.view(-1)[:] = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + + k = kv[:, :H, :] + v = kv[:, H:, :] + k_scale = kv_scale[:, :H, :] + v_scale = kv_scale[:, H:, :] + + ground_out = torch_dequant( + kv=kv, + kv_scale=kv_scale, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + odtype=out.dtype, + group_quant_size=group_quant_size, + ) + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=out[:, :H, :], + v_out=out[:, H:, :], + max_len_in_batch=max_input_len, + quant_group_size=group_quant_size, + ) + assert torch.allclose(out, ground_out, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(out.flatten().float(), ground_out.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() From f44623eb30b2e299ac8678b3794377f775300f86 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 07:05:13 +0000 Subject: [PATCH 011/114] fix int8kv prefill attention kernel --- .../attention/int8kv_triton_backend.py | 74 +++++++++++- .../context_flashattention_nopad.py | 66 +++++------ .../test_context_flashattention_nopad.py | 106 ++++++++++++++++++ 3 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index e9783e5a27..bf0b16d384 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -17,8 +17,15 @@ def create_att_decode_state(self, infer_state) -> "Int8kvTritonDecodeAttState": @dataclasses.dataclass class Int8kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + def init_state(self): - pass + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) def copy_for_prefill_cuda_graph(self, new_state: "Int8kvTritonPrefillAttState"): pass @@ -34,12 +41,71 @@ def prefill_att( ) -> torch.Tensor: assert att_control.use_alibi is False + self.backend: Int8kvTritonAttBackend = self.backend # for typing + if self.backend.quant_group_size == 8: + pass + # context_attention_fwd_ppl_int8kv( + # q.view(-1, self.tp_q_head_num_, self.head_dim_), + # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], + # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], + # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + # infer_state.b_start_loc, + # infer_state.b_seq_len, + # infer_state.max_len_in_batch, + # infer_state.b_ready_cache_len, + # ) + return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) - def _nomarl_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + def _groupsize8_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, ): - raise NotImplementedError("not impl") + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ..triton_kernel.kv_copy.ppl_quant_copy_kv import dequantize_int8kv + + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + # context_attention_fwd_ppl_int8kv( + # q.view(-1, self.tp_q_head_num_, self.head_dim_), + # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], + # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], + # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + # infer_state.b_start_loc, + # infer_state.b_seq_len, + # infer_state.max_len_in_batch, + # infer_state.b_ready_cache_len, + # ) @dataclasses.dataclass diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index e36c51b394..8ba1f8d9ac 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -343,18 +343,17 @@ def _fwd_kernel_int8kv( sm_scale, Out, B_Start_Loc, + B_kv_start_loc, B_Seqlen, b_prompt_cache_len, stride_qbs, stride_qh, stride_qd, - stride_kb, - stride_kh, stride_ks, + stride_kh, stride_kd, - stride_vb, - stride_vh, stride_vs, + stride_vh, stride_vd, stride_obs, stride_oh, @@ -374,6 +373,7 @@ def _fwd_kernel_int8kv( prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + kv_start_loc = tl.load(B_kv_start_loc + cur_batch) block_start_loc = BLOCK_M * start_m @@ -393,6 +393,9 @@ def _fwd_kernel_int8kv( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + stride_ks = tl.cast(stride_ks, tl.int64) + stride_vs = tl.cast(stride_vs, tl.int64) + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) # causal mask @@ -405,8 +408,7 @@ def _fwd_kernel_int8kv( # other=0, # ) off_k = ( - cur_batch * stride_kb - + (start_n + offs_n[None, :]) * stride_ks + +(kv_start_loc + start_n + offs_n[None, :]) * stride_ks + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd ) @@ -432,8 +434,7 @@ def _fwd_kernel_int8kv( # other=0.0, # ) off_v = ( - cur_batch * stride_vb - + (start_n + offs_n[:, None]) * stride_vs + +(kv_start_loc + start_n + offs_n[:, None]) * stride_vs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd ) @@ -455,7 +456,9 @@ def _fwd_kernel_int8kv( @torch.no_grad() -def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len): +def context_attention_fwd_ppl_int8kv( + q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, max_q_input_len, b_prompt_cache_len +): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -468,34 +471,33 @@ def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_inp batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + grid = lambda meta: (triton.cdiv(max_q_input_len, meta["BLOCK_M"]), batch * head, 1) BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 _fwd_kernel_int8kv[grid]( - q, - k, - v, - sm_scale, - o, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out=o, + B_Start_Loc=b_start_loc, + B_kv_start_loc=b_kv_start_loc, + B_Seqlen=b_seq_len, + b_prompt_cache_len=b_prompt_cache_len, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + stride_ks=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + stride_vs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + stride_obs=o.stride(0), + stride_oh=o.stride(1), + stride_od=o.stride(2), kv_group_num=kv_group_num, H=head, BLOCK_DMODEL=Lk, diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py new file mode 100644 index 0000000000..5e4d815824 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py @@ -0,0 +1,106 @@ +import torch +import time +import pytest +from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( + context_attention_fwd_ppl_int8kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_cache_len): + batch = b_start_loc.shape[0] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + for i in range(batch): + start_loc = b_start_loc[i] + seq_len = b_seq_len[i] + prompt_cache_len = b_prompt_cache_len[i] + cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] + cur_q = cur_q.clone().to(torch.float32) + cur_k = k[i, :seq_len, :] + cur_k = cur_k.clone().to(torch.float32) + + cur_v = v[i, :seq_len, :] + cur_v = cur_v.clone().to(torch.float32) + + cur_q = cur_q.transpose(0, 1) + cur_k = cur_k.transpose(0, 1) + cur_v = cur_v.transpose(0, 1) + dk = cur_q.shape[-1] + + p = torch.matmul(cur_q, cur_k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) + + q_index = torch.arange(cur_q.shape[1]).unsqueeze(-1).to(p.device) + k_index = torch.arange(cur_k.shape[1]).unsqueeze(0).to(p.device) + mask = (q_index + prompt_cache_len >= k_index).int() + mask = mask.unsqueeze(0).expand(cur_q.shape[0], -1, -1) + + p = p.masked_fill(mask == 0, float("-inf")) + + s = torch.nn.functional.softmax(p, dim=-1) + + o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, prompt_cache_len", + [ + (b, H, N_CTX, D_HEAD, prompt_cache_len) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for prompt_cache_len in [0, 56, 200] + ], +) +def test_context_attention_fwd_ppl_int8kv(B, H, N_CTX, D_HEAD, prompt_cache_len): + dtype = torch.float16 + prompt_cache_len = 0 + if prompt_cache_len > N_CTX: + return + + q = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + kv = torch.empty((B, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + k = kv[:, :H, :] + v = kv[:, H:, :] + + o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + torch_o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + + max_q_input_len = N_CTX - prompt_cache_len + + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_prompt_cache_len = torch.zeros(B, dtype=torch.int32, device="cuda") + + for i in range(B): + b_seq_len[i] = N_CTX + if i != 0: + b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len + b_prompt_cache_len[i] = prompt_cache_len + + torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_seq_len, b_prompt_cache_len) + + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + context_attention_fwd_ppl_int8kv( + q=q, + k=k, + v=v, + o=o, + b_start_loc=b_start_loc, + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + max_q_input_len=max_q_input_len, + b_prompt_cache_len=b_prompt_cache_len, + ) + + assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(o.flatten().float(), torch_o.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() From fa0f51247c15ff4eba6f5c939cd27b26ee16d0bf Mon Sep 17 00:00:00 2001 From: root Date: Tue, 6 Jan 2026 12:20:00 +0000 Subject: [PATCH 012/114] fix unittest --- .../context_flashattention_nopad.py | 4 +-- .../test_context_flashattention_nopad.py | 34 +++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 8ba1f8d9ac..29a56603b9 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -408,7 +408,7 @@ def _fwd_kernel_int8kv( # other=0, # ) off_k = ( - +(kv_start_loc + start_n + offs_n[None, :]) * stride_ks + (kv_start_loc + start_n + offs_n[None, :]) * stride_ks + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd ) @@ -434,7 +434,7 @@ def _fwd_kernel_int8kv( # other=0.0, # ) off_v = ( - +(kv_start_loc + start_n + offs_n[:, None]) * stride_vs + (kv_start_loc + start_n + offs_n[:, None]) * stride_vs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd ) diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py index 5e4d815824..5ab4be6b33 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py @@ -9,35 +9,34 @@ logger = init_logger(__name__) -def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_cache_len): +def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len): batch = b_start_loc.shape[0] - k = k.transpose(1, 2) - v = v.transpose(1, 2) + for i in range(batch): start_loc = b_start_loc[i] + kv_start_loc = b_kv_start_loc[i] seq_len = b_seq_len[i] prompt_cache_len = b_prompt_cache_len[i] cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] cur_q = cur_q.clone().to(torch.float32) - cur_k = k[i, :seq_len, :] + cur_k = k[kv_start_loc : (kv_start_loc + seq_len), :, :] cur_k = cur_k.clone().to(torch.float32) - cur_v = v[i, :seq_len, :] + cur_v = v[kv_start_loc : (kv_start_loc + seq_len), :, :] cur_v = cur_v.clone().to(torch.float32) - cur_q = cur_q.transpose(0, 1) - cur_k = cur_k.transpose(0, 1) - cur_v = cur_v.transpose(0, 1) + dk = cur_q.shape[-1] + cur_q = cur_q.permute(1, 0, 2) + cur_k = cur_k.permute(1, 2, 0) + cur_v = cur_v.permute(1, 0, 2) dk = cur_q.shape[-1] - p = torch.matmul(cur_q, cur_k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) + p = torch.matmul(cur_q, cur_k) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) - q_index = torch.arange(cur_q.shape[1]).unsqueeze(-1).to(p.device) - k_index = torch.arange(cur_k.shape[1]).unsqueeze(0).to(p.device) - mask = (q_index + prompt_cache_len >= k_index).int() - mask = mask.unsqueeze(0).expand(cur_q.shape[0], -1, -1) + q_index = (torch.arange(cur_q.shape[1]).to(p.device) + prompt_cache_len).view(-1, 1) + k_index = torch.arange(seq_len).to(p.device).view(1, -1) - p = p.masked_fill(mask == 0, float("-inf")) + p[:, (q_index < k_index)] = float("-inf") s = torch.nn.functional.softmax(p, dim=-1) @@ -58,11 +57,11 @@ def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_ca def test_context_attention_fwd_ppl_int8kv(B, H, N_CTX, D_HEAD, prompt_cache_len): dtype = torch.float16 prompt_cache_len = 0 - if prompt_cache_len > N_CTX: + if prompt_cache_len >= N_CTX - 1: return q = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - kv = torch.empty((B, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) k = kv[:, :H, :] v = kv[:, H:, :] @@ -82,9 +81,8 @@ def test_context_attention_fwd_ppl_int8kv(B, H, N_CTX, D_HEAD, prompt_cache_len) b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len b_prompt_cache_len[i] = prompt_cache_len - torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_seq_len, b_prompt_cache_len) - b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len) context_attention_fwd_ppl_int8kv( q=q, k=k, From 650db11c5457f8ac24f5960702fca0c2fcb67618 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 12:30:16 +0000 Subject: [PATCH 013/114] fix int8kv att backend. --- .../attention/int8kv_triton_backend.py | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index bf0b16d384..c3de9c5640 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -44,20 +44,20 @@ def prefill_att( self.backend: Int8kvTritonAttBackend = self.backend # for typing if self.backend.quant_group_size == 8: pass - # context_attention_fwd_ppl_int8kv( - # q.view(-1, self.tp_q_head_num_, self.head_dim_), - # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - # infer_state.b_start_loc, - # infer_state.b_seq_len, - # infer_state.max_len_in_batch, - # infer_state.b_ready_cache_len, - # ) - - return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) - - def _groupsize8_quant_prefill_att( + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( self, q: torch.Tensor, k: torch.Tensor, @@ -76,6 +76,7 @@ def _groupsize8_quant_prefill_att( total_token_num = self.infer_state.total_token_num k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) max_kv_seq_len = self.infer_state.max_kv_seq_len @@ -96,16 +97,19 @@ def _groupsize8_quant_prefill_att( quant_group_size=self.backend.quant_group_size, ) - # context_attention_fwd_ppl_int8kv( - # q.view(-1, self.tp_q_head_num_, self.head_dim_), - # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - # infer_state.b_start_loc, - # infer_state.b_seq_len, - # infer_state.max_len_in_batch, - # infer_state.b_ready_cache_len, - # ) + from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_ppl_int8kv + + context_attention_fwd_ppl_int8kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) @dataclasses.dataclass From cc55263901c0d0555567ecfcf524c887b60010d7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 12:31:03 +0000 Subject: [PATCH 014/114] fix --- lightllm/common/basemodel/attention/int8kv_triton_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index c3de9c5640..da34583858 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -66,7 +66,7 @@ def _groupsize_quant_prefill_att( v_scale: torch.Tensor, layer_weight, alloc_func=torch.empty, - ): + ) -> torch.Tensor: # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) # batch_size = self.infer_state.b_seq_len.shape[0] @@ -110,6 +110,7 @@ def _groupsize_quant_prefill_att( max_q_input_len=self.infer_state.max_q_seq_len, b_prompt_cache_len=self.infer_state.b_ready_cache_len, ) + return o_tensor @dataclasses.dataclass From 2ab9a1ebdf2c4f66afdbd3d15d199ed9ee96c523 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 13:12:18 +0000 Subject: [PATCH 015/114] fix diverse att. --- .../attention/int8kv_triton_backend.py | 35 ++++++++++-------- .../att/decode_att/int8kv_gqa/__init__.py | 0 .../ppl_int8kv_flash_decoding_diverse.py | 5 +-- ...pl_int8kv_flash_decoding_diverse_stage1.py | 0 ...pl_int8kv_flash_decoding_diverse_stage3.py | 0 .../layer_infer/transformer_layer_infer.py | 36 ------------------- 6 files changed, 23 insertions(+), 53 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/__init__.py rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv_gqa}/ppl_int8kv_flash_decoding_diverse.py (98%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv_gqa}/ppl_int8kv_flash_decoding_diverse_stage1.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv_gqa}/ppl_int8kv_flash_decoding_diverse_stage3.py (100%) mode change 100755 => 100644 lightllm/models/llama/layer_infer/transformer_layer_infer.py diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index da34583858..5f01cf8be6 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -2,6 +2,7 @@ import torch from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, Tuple +from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel class Int8kvTritonAttBackend(BaseAttBackend): @@ -131,23 +132,14 @@ def decode_att( alloc_func=torch.empty, ): assert att_control.use_alibi is False - q = q k, k_scale = k v, v_scale = v - if k_scale.ndim == 3 and k_scale.shape[2] == 1: - return self._per_head_quant_decode_stage3_att( - q=q, - k=k, - k_scale=k_scale, - v=v, - v_scale=v, - layer_weight=layer_weight, - alloc_func=alloc_func, + if enable_diverse_mode_gqa_decode_fast_kernel(): + return self.diverse_decode_att( + q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, layer_weight=layer_weight, alloc_func=alloc_func ) - else: - raise NotImplementedError("not support decode att") - def _per_head_quant_decode_stage3_att( + def diverse_decode_att( self, q: torch.Tensor, k: torch.Tensor, @@ -156,5 +148,18 @@ def _per_head_quant_decode_stage3_att( v_scale: torch.Tensor, layer_weight, alloc_func=torch.empty, - ): - raise NotImplementedError("error") + ) -> torch.Tensor: + + from ..triton_kernel.att.decode_att.int8kv_gqa.ppl_int8kv_flash_decoding_diverse import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse.py similarity index 98% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse.py index 84054bf867..d42a1a12a3 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse.py @@ -10,8 +10,6 @@ def token_decode_attention_flash_decoding( q, infer_state: InferStateInfo, - q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -28,6 +26,9 @@ def token_decode_attention_flash_decoding( stream1 = shared_streams_dict["stream1"] stream2 = shared_streams_dict["stream2"] + q_head_num = q.shape[1] + head_dim = q.shape[2] + BLOCK_SEQ = 256 batch_size = infer_state.batch_size max_len_in_batch = infer_state.max_len_in_batch diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage1.py similarity index 100% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage1.py diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100755 new mode 100644 index bcfc710861..f55cd618b2 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -104,14 +104,6 @@ def _bind_attention(self): # self._context_attention_kernel = partial( # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self # ) - elif "ppl_int8kv_flashdecoding_diverse" in self.mode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - # ) elif "ppl_int8kv_flashdecoding" in self.mode: # self._token_attention_kernel = partial( # LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self @@ -650,34 +642,6 @@ def _token_decode_attention_ppl_int8kv_flashdecoding( alloc_tensor_func=self.alloc_tensor, ) - def _token_decode_attention_ppl_int8kv_flashdecoding_diverse( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( - token_decode_attention_flash_decoding, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - def _token_decode_attention_ppl_int4kv_flashdecoding( self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None ): From 70cbed2c70588c5664af71d824a606effa092291 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 13:32:05 +0000 Subject: [PATCH 016/114] fix --- .../attention/int8kv_triton_backend.py | 36 ++++++++++++++++++- .../{int8kv_gqa => int8kv}/__init__.py | 0 .../int8kv}/ppl_int8kv_flash_decoding.py | 5 ++- .../ppl_int8kv_flash_decoding_diverse.py | 0 ...pl_int8kv_flash_decoding_diverse_stage1.py | 0 ...pl_int8kv_flash_decoding_diverse_stage3.py | 0 .../layer_infer/transformer_layer_infer.py | 36 ------------------- 7 files changed, 37 insertions(+), 40 deletions(-) rename lightllm/common/basemodel/triton_kernel/att/decode_att/{int8kv_gqa => int8kv}/__init__.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/ppl_int8kv_flash_decoding.py (91%) rename lightllm/common/basemodel/triton_kernel/att/decode_att/{int8kv_gqa => int8kv}/ppl_int8kv_flash_decoding_diverse.py (100%) rename lightllm/common/basemodel/triton_kernel/att/decode_att/{int8kv_gqa => int8kv}/ppl_int8kv_flash_decoding_diverse_stage1.py (100%) rename lightllm/common/basemodel/triton_kernel/att/decode_att/{int8kv_gqa => int8kv}/ppl_int8kv_flash_decoding_diverse_stage3.py (100%) diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index 5f01cf8be6..9ec57d0c59 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -138,6 +138,16 @@ def decode_att( return self.diverse_decode_att( q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, layer_weight=layer_weight, alloc_func=alloc_func ) + else: + return self.ppl_mha_int8kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) def diverse_decode_att( self, @@ -150,7 +160,31 @@ def diverse_decode_att( alloc_func=torch.empty, ) -> torch.Tensor: - from ..triton_kernel.att.decode_att.int8kv_gqa.ppl_int8kv_flash_decoding_diverse import ( + from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) + + def ppl_mha_int8kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py similarity index 91% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py index 88e39b82fc..a02ce88a95 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py @@ -5,8 +5,6 @@ def token_decode_attention_flash_decoding( q, infer_state, - q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -15,11 +13,12 @@ def token_decode_attention_flash_decoding( alloc_tensor_func=torch.empty, ): BLOCK_SEQ = 256 + q_head_num, head_dim = q.shape[1], q.shape[2] batch_size = infer_state.batch_size max_len_in_batch = infer_state.max_len_in_batch calcu_shape1 = (batch_size, q_head_num, head_dim) - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv_gqa/ppl_int8kv_flash_decoding_diverse_stage3.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index f55cd618b2..439cadd6bc 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -104,16 +104,6 @@ def _bind_attention(self): # self._context_attention_kernel = partial( # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self # ) - elif "ppl_int8kv_flashdecoding" in self.mode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self - # ) - # self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, - # self) - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - # ) - pass elif "ppl_int4kv_flashdecoding" in self.mode: # self._token_attention_kernel = partial( # LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self @@ -616,32 +606,6 @@ def _token_decode_attention_ppl_fp16_flashdecoding( alloc_tensor_func=self.alloc_tensor, ) - def _token_decode_attention_ppl_int8kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - def _token_decode_attention_ppl_int4kv_flashdecoding( self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None ): From b0abacf9cee91f651265ed26a1442482114c55c3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 14:17:46 +0000 Subject: [PATCH 017/114] add int4kv. --- .../attention/int4kv_triton_backend.py | 199 ++++++++++++++++++ .../attention/int8kv_triton_backend.py | 2 +- .../kv_copy/ppl_int4kv_copy_kv.py | 136 ++++++++++++ ...quant_copy_kv.py => ppl_int8kv_copy_kv.py} | 0 ..._copy_kv.py => test_ppl_int8kv_copy_kv.py} | 2 +- 5 files changed, 337 insertions(+), 2 deletions(-) create mode 100644 lightllm/common/basemodel/attention/int4kv_triton_backend.py create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py rename lightllm/common/basemodel/triton_kernel/kv_copy/{ppl_quant_copy_kv.py => ppl_int8kv_copy_kv.py} (100%) rename unit_tests/common/basemodel/triton_kernel/kv_copy/{test_ppl_quant_copy_kv.py => test_ppl_int8kv_copy_kv.py} (97%) diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/int4kv_triton_backend.py new file mode 100644 index 0000000000..1e8d88518b --- /dev/null +++ b/lightllm/common/basemodel/attention/int4kv_triton_backend.py @@ -0,0 +1,199 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple +from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel + + +class Int4kvTritonAttBackend(BaseAttBackend): + def __init__(self, quant_group_size: int): + self.quant_group_size: int = quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int4kvTritonPrefillAttState": + return Int4kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int4kvTritonDecodeAttState": + return Int4kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int4kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + + def init_state(self): + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) + + def copy_for_prefill_cuda_graph(self, new_state: "Int4kvTritonPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + + self.backend: Int4kvTritonAttBackend = self.backend # for typing + if self.backend.quant_group_size == 8: + pass + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ) -> torch.Tensor: + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ..triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv + + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_ppl_int8kv + + context_attention_fwd_ppl_int8kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) + return o_tensor + + +@dataclasses.dataclass +class Int4kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int4kvTritonDecodeAttState"): + pass + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + k, k_scale = k + v, v_scale = v + if enable_diverse_mode_gqa_decode_fast_kernel(): + return self.diverse_decode_att( + q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, layer_weight=layer_weight, alloc_func=alloc_func + ) + else: + return self.ppl_mha_int8kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def diverse_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ) -> torch.Tensor: + + from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) + + def ppl_mha_int8kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index 9ec57d0c59..bcf927ad5e 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -81,7 +81,7 @@ def _groupsize_quant_prefill_att( max_kv_seq_len = self.infer_state.max_kv_seq_len - from ..triton_kernel.kv_copy.ppl_quant_copy_kv import dequantize_int8kv + from ..triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv dequantize_int8kv( k=k, diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py new file mode 100644 index 0000000000..e58f785235 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -0,0 +1,136 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_int4_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_count, + token_num, + HEAD_NUM: tl.constexpr, + BLOCK_GROUP_COUNT: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_index = tl.program_id(0) + + for cur_index in range(start_index, token_num, step=tl.num_programs(axis=0)): + offs_g = tl.arange(0, BLOCK_GROUP_COUNT) % group_count + offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + for cur_head in tl.static_range(HEAD_NUM, step=1): + src_data_0 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2, + other=0.0, + ) + src_data_1 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2 + + 1, + other=0.0, + ) + + abs_data_0 = tl.abs(src_data_0) + abs_data_1 = tl.abs(src_data_1) + + data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to( + Out_scale.dtype.element_ty + ) + q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) + q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) + q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) + + q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) + q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) + q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) + + low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) + high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 + + out_data = low_4 | high_4 + + o_ptrs = ( + Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + ) + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, out_data) + tl.store(os_ptrs, data_scale) + return + + +@torch.no_grad() +def destindex_copy_int4kv( + KV: torch.Tensor, + DestLoc: torch.Tensor, + KV_buffer: torch.Tensor, + KV_scale_buffer: torch.Tensor, + quant_group_size: int, +): + head_num = KV.shape[1] + head_dim = KV.shape[2] + + assert head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + + group_count = head_dim // quant_group_size + group_dim = quant_group_size + + assert triton.next_power_of_2(group_dim) == group_dim + + KV = KV.view((KV.shape[0], head_num, group_count, group_dim)) + KV_buffer = KV_buffer.view( + KV_buffer.shape[0], KV_buffer.shape[1], group_count, group_dim // 2 + ) # OUt 是 int8 类型, 两个int4组一个int8,所以 group_dim // 2 + KV_scale_buffer = KV_scale_buffer.view(KV_scale_buffer.shape[0], KV_scale_buffer.shape[1], group_count) + if len(DestLoc) < 1024: + grid = (len(DestLoc),) + else: + grid = (1024,) + + _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( + K=KV, + Dest_loc=DestLoc, + Out=KV_buffer, + Out_scale=KV_scale_buffer, + stride_k_bs=KV.stride(0), + stride_k_h=KV.stride(1), + stride_k_g=KV.stride(2), + stride_k_d=KV.stride(3), + stride_o_bs=KV_buffer.stride(0), + stride_o_h=KV_buffer.stride(1), + stride_o_g=KV_buffer.stride(2), + stride_o_d=KV_buffer.stride(3), + stride_os_bs=KV_scale_buffer.stride(0), + stride_os_h=KV_scale_buffer.stride(1), + stride_os_g=KV_scale_buffer.stride(2), + group_count=group_count, + token_num=len(DestLoc), + HEAD_NUM=head_num, + BLOCK_GROUP_COUNT=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + num_warps=4, + num_stages=1, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/kv_copy/ppl_quant_copy_kv.py rename to lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py similarity index 97% rename from unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py rename to unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py index a3449de808..149c9894ae 100644 --- a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_quant_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py @@ -1,7 +1,7 @@ import torch import time import pytest -from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_quant_copy_kv import ( +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import ( dequantize_int8kv, destindex_copy_quantize_kv, ) From 41ecffde560f1b70a5a020c57477f8c2361b6c3c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 15:02:29 +0000 Subject: [PATCH 018/114] add int4 kernel --- .../kv_copy/ppl_int4kv_copy_kv.py | 234 ++++++++++++++++++ .../kv_copy/test_ppl_int4kv_copy_kv.py | 8 + 2 files changed, 242 insertions(+) create mode 100644 unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py index e58f785235..672c8009e0 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -134,3 +134,237 @@ def destindex_copy_int4kv( num_stages=1, ) return + + +@triton.jit +def _fwd_dequantize_int4kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) % group_dim + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): + k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] // 2 + ) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = tl.where(k_high >= 8, k_high - 16, k_high) + k_low = tl.where(k_low >= 8, k_low - 16, k_low) + + k_int4 = tl.where( + offs_d[None, None, :] % 2 == 0, + k_low, + k_high, + ) + + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + + group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int4.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] + ) + v_high = (v_int8 & 0xF0) >> 4 + v_low = v_int8 & 0x0F + v_high = tl.where(v_high >= 8, v_high - 16, v_high) + v_low = tl.where(v_low >= 8, v_low - 16, v_low) + + v_int4 = tl.where( + offs_d[None, None, :] % 2 == 0, + v_low, + v_high, + ) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int4.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int4kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] * 2 + v_head_num = v.shape[1] + v_head_dim = v.shape[2] * 2 + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim // 2)) # int4kv 以 int8 存储的 + v = v.view((v.shape[0], v.shape[1], group_count, group_dim // 2)) + k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) + v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1)) + + # 使拆分的grid 具有足够的并行度 + SEQ_BLOCK_SIZE = 128 + while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512: + SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2 + if SEQ_BLOCK_SIZE <= 1: + break + + if SEQ_BLOCK_SIZE <= 1: + SEQ_BLOCK_SIZE = 8 + + grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size) + num_warps = 4 + k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim)) + v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim)) + + _fwd_dequantize_int4kv[grid]( + k=k, + k_ss=k.stride(0), + k_sh=k.stride(1), + k_sg=k.stride(2), + k_sd=k.stride(3), + k_scale=k_scale, + k_scale_ss=k_scale.stride(0), + k_scale_sh=k_scale.stride(1), + k_scale_sg=k_scale.stride(2), + k_scale_sd=k_scale.stride(2), + v=v, + v_ss=v.stride(0), + v_sh=v.stride(1), + v_sg=v.stride(2), + v_sd=v.stride(3), + v_scale=v_scale, + v_scale_ss=v_scale.stride(0), + v_scale_sh=v_scale.stride(1), + v_scale_sg=v_scale.stride(2), + v_scale_sd=v_scale.stride(3), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + stride_req_to_tokens_s=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=k_out, + k_out_ss=k_out.stride(0), + k_out_sh=k_out.stride(1), + k_out_sg=k_out.stride(2), + k_out_sd=k_out.stride(3), + v_out=v_out, + v_out_ss=v_out.stride(0), + v_out_sh=v_out.stride(1), + v_out_sg=v_out.stride(2), + v_out_sd=v_out.stride(3), + k_head_num=k_head_num, + v_head_num=v_head_num, + group_count=group_count, + group_dim=group_dim, + SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, + GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py new file mode 100644 index 0000000000..708b1bd571 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -0,0 +1,8 @@ +import torch +import pytest +import numpy as np +from typing import Tuple +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) From ccaaa6ac6e052a050a45f38a87e9309f1f21a379 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 15:38:13 +0000 Subject: [PATCH 019/114] fix all --- .../kv_copy/ppl_int4kv_copy_kv.py | 8 +-- .../kv_copy/test_ppl_int4kv_copy_kv.py | 57 ++++++++++++++++++- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py index 672c8009e0..e917d54b38 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -41,7 +41,6 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2, - other=0.0, ) src_data_1 = tl.load( K @@ -50,7 +49,6 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1, - other=0.0, ) abs_data_0 = tl.abs(src_data_0) @@ -62,10 +60,12 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) + q_src_data_0 = tl.cast(q_src_data_0, tl.uint8) q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) + q_src_data_1 = tl.cast(q_src_data_1, tl.uint8) low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 @@ -206,7 +206,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * k_sg + offs_d[None, None, :] // 2 ) - k_high = (k_int8 & 0xF0) >> 4 + k_high = tl.cast((tl.cast(k_int8, tl.uint8) & 0xF0) >> 4, tl.int8) k_low = k_int8 & 0x0F k_high = tl.where(k_high >= 8, k_high - 16, k_high) k_low = tl.where(k_low >= 8, k_low - 16, k_low) @@ -242,7 +242,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * v_sg + offs_d[None, None, :] ) - v_high = (v_int8 & 0xF0) >> 4 + v_high = tl.cast((tl.cast(v_int8, tl.uint8) & 0xF0) >> 4, tl.int8) v_low = v_int8 & 0x0F v_high = tl.where(v_high >= 8, v_high - 16, v_high) v_low = tl.where(v_low >= 8, v_low - 16, v_low) diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py index 708b1bd571..231f11aab4 100644 --- a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -2,7 +2,62 @@ import pytest import numpy as np from typing import Tuple -from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv, dequantize_int4kv from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) + + +def test_quanted_and_dequant(): + """Test quantization followed by dequantization.""" + batch_size = 1 + seq_len = 8 + head_num = 4 + k_head_num = 2 + v_head_num = 2 + assert k_head_num + v_head_num == head_num + head_dim = 64 + quant_group_size = 8 + + # Create original data + original_kv = torch.randn(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).clamp_(-0.5, 0.5).cuda() + dest_loc = torch.arange(batch_size * seq_len, dtype=torch.int64).cuda() + + # Quantize + group_count = head_dim // quant_group_size + kv_buffer = torch.zeros(batch_size * seq_len, head_num, head_dim // 2, dtype=torch.int8).cuda() + kv_scale_buffer = torch.zeros(batch_size * seq_len, head_num, group_count, dtype=torch.float32).cuda() + destindex_copy_int4kv(original_kv, dest_loc, kv_buffer, kv_scale_buffer, quant_group_size) + + # Dequantize + req_to_token_indexs = torch.arange(seq_len, dtype=torch.int64).unsqueeze(0).cuda() + b_seq_len = torch.tensor([seq_len], dtype=torch.int32).cuda() + b_req_idx = torch.tensor([0], dtype=torch.int32).cuda() + b_kv_start_loc = torch.tensor([0], dtype=torch.int32).cuda() + + recovered_kv = torch.zeros(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).cuda() + + dequantize_int4kv( + k=kv_buffer[:, 0:k_head_num, :], + k_scale=kv_scale_buffer[:, k_head_num:, :], + v=kv_buffer[:, k_head_num:, :], + v_scale=kv_scale_buffer[:, k_head_num:, :], + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=recovered_kv[:, :k_head_num, :], + v_out=recovered_kv[:, k_head_num:, :], + max_len_in_batch=seq_len, + quant_group_size=quant_group_size, + ) + + logger.info("Round-trip test completed!") + + # assert torch.allclose(recovered_kv, original_kv, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() From 9961261ecdaf9efb85e05d7acd17994f451e2b81 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 15:42:11 +0000 Subject: [PATCH 020/114] fix unit test --- .../triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py index 231f11aab4..a213dd8fc6 100644 --- a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -20,7 +20,7 @@ def test_quanted_and_dequant(): quant_group_size = 8 # Create original data - original_kv = torch.randn(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).clamp_(-0.5, 0.5).cuda() + original_kv = torch.randn(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).clamp_(-1, 1).cuda() dest_loc = torch.arange(batch_size * seq_len, dtype=torch.int64).cuda() # Quantize @@ -39,7 +39,7 @@ def test_quanted_and_dequant(): dequantize_int4kv( k=kv_buffer[:, 0:k_head_num, :], - k_scale=kv_scale_buffer[:, k_head_num:, :], + k_scale=kv_scale_buffer[:, 0:k_head_num, :], v=kv_buffer[:, k_head_num:, :], v_scale=kv_scale_buffer[:, k_head_num:, :], req_to_token_indexs=req_to_token_indexs, @@ -54,7 +54,7 @@ def test_quanted_and_dequant(): logger.info("Round-trip test completed!") - # assert torch.allclose(recovered_kv, original_kv, atol=1e-2, rtol=0) + assert torch.allclose(recovered_kv, original_kv, atol=2 / 14 * 2, rtol=0) cos = torch.nn.CosineSimilarity(0) assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 From 06e2369ef39bae78e5ea7106d2b5a3933798b8bf Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 15:50:25 +0000 Subject: [PATCH 021/114] fix unit test --- .../basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py index e917d54b38..3ca5b18be3 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -61,11 +61,12 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) q_src_data_0 = tl.cast(q_src_data_0, tl.uint8) + q_src_data_0 = q_src_data_0.to(tl.uint8, bitcast=True) q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) - q_src_data_1 = tl.cast(q_src_data_1, tl.uint8) + q_src_data_1 = q_src_data_1.to(tl.uint8, bitcast=True) low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 @@ -206,7 +207,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * k_sg + offs_d[None, None, :] // 2 ) - k_high = tl.cast((tl.cast(k_int8, tl.uint8) & 0xF0) >> 4, tl.int8) + k_high = ((k_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) k_low = k_int8 & 0x0F k_high = tl.where(k_high >= 8, k_high - 16, k_high) k_low = tl.where(k_low >= 8, k_low - 16, k_low) @@ -242,7 +243,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * v_sg + offs_d[None, None, :] ) - v_high = tl.cast((tl.cast(v_int8, tl.uint8) & 0xF0) >> 4, tl.int8) + v_high = ((v_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) v_low = v_int8 & 0x0F v_high = tl.where(v_high >= 8, v_high - 16, v_high) v_low = tl.where(v_low >= 8, v_low - 16, v_low) From ef803342e0f6b770ea40285f1c87350e2782402f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 6 Jan 2026 15:54:22 +0000 Subject: [PATCH 022/114] fix --- .../triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py index a213dd8fc6..0e01529083 100644 --- a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -54,9 +54,10 @@ def test_quanted_and_dequant(): logger.info("Round-trip test completed!") - assert torch.allclose(recovered_kv, original_kv, atol=2 / 14 * 2, rtol=0) + # assert torch.allclose(recovered_kv, original_kv, atol=2 / 14 * 2, rtol=0) cos = torch.nn.CosineSimilarity(0) - assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 + print(recovered_kv.flatten().float()[0:10], original_kv.flatten().float()[0:10], flush=True) + assert cos(recovered_kv.flatten().float()[0:10], original_kv.flatten().float()[0:10]) > 0.99 if __name__ == "__main__": From bddb2cf8712f5241d336929914df9254e8629c38 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 02:10:04 +0000 Subject: [PATCH 023/114] fix all --- .../triton_kernel/kv_copy/ppl_int4kv_copy_kv.py | 13 +++++++------ .../kv_copy/test_ppl_int4kv_copy_kv.py | 8 +++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py index 3ca5b18be3..f7268a853d 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -60,7 +60,6 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) - q_src_data_0 = tl.cast(q_src_data_0, tl.uint8) q_src_data_0 = q_src_data_0.to(tl.uint8, bitcast=True) q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) @@ -71,7 +70,7 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 - out_data = low_4 | high_4 + out_data = (low_4 | high_4).to(tl.int8, bitcast=True) o_ptrs = ( Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] @@ -130,7 +129,7 @@ def destindex_copy_int4kv( token_num=len(DestLoc), HEAD_NUM=head_num, BLOCK_GROUP_COUNT=triton.next_power_of_2(group_count), - BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + BLOCK_GROUP_DIM=group_dim, num_warps=4, num_stages=1, ) @@ -195,7 +194,7 @@ def _fwd_dequantize_int4kv( offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) % group_dim + offs_d = tl.arange(0, BLOCK_GROUP_DIM) offs_scale_d = tl.arange(0, 1) group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count @@ -241,7 +240,7 @@ def _fwd_dequantize_int4kv( + kv_loc[:, None, None] * v_ss + v_head_index * v_sh + group_offs[None, :, None] * v_sg - + offs_d[None, None, :] + + offs_d[None, None, :] // 2 ) v_high = ((v_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) v_low = v_int8 & 0x0F @@ -301,6 +300,8 @@ def dequantize_int4kv( group_count = k_head_dim // quant_group_size group_dim = quant_group_size + assert triton.next_power_of_2(group_dim) == group_dim + k = k.view((k.shape[0], k.shape[1], group_count, group_dim // 2)) # int4kv 以 int8 存储的 v = v.view((v.shape[0], v.shape[1], group_count, group_dim // 2)) k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) @@ -364,7 +365,7 @@ def dequantize_int4kv( group_dim=group_dim, SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), - BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + BLOCK_GROUP_DIM=group_dim, num_warps=num_warps, num_stages=1, ) diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py index 0e01529083..83537ec708 100644 --- a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -30,7 +30,7 @@ def test_quanted_and_dequant(): destindex_copy_int4kv(original_kv, dest_loc, kv_buffer, kv_scale_buffer, quant_group_size) # Dequantize - req_to_token_indexs = torch.arange(seq_len, dtype=torch.int64).unsqueeze(0).cuda() + req_to_token_indexs = torch.arange(seq_len, dtype=torch.int64).view(1, -1).cuda() b_seq_len = torch.tensor([seq_len], dtype=torch.int32).cuda() b_req_idx = torch.tensor([0], dtype=torch.int32).cuda() b_kv_start_loc = torch.tensor([0], dtype=torch.int32).cuda() @@ -53,11 +53,9 @@ def test_quanted_and_dequant(): ) logger.info("Round-trip test completed!") - - # assert torch.allclose(recovered_kv, original_kv, atol=2 / 14 * 2, rtol=0) + assert torch.allclose(recovered_kv, original_kv, atol=2 / 14, rtol=0) cos = torch.nn.CosineSimilarity(0) - print(recovered_kv.flatten().float()[0:10], original_kv.flatten().float()[0:10], flush=True) - assert cos(recovered_kv.flatten().float()[0:10], original_kv.flatten().float()[0:10]) > 0.99 + assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 if __name__ == "__main__": From e9f446223599964ca68459724c3dddcf4ca23a93 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 02:17:37 +0000 Subject: [PATCH 024/114] fix all --- .../attention/int4kv_triton_backend.py | 8 +- .../attention/int8kv_triton_backend.py | 4 +- .../context_flashattention_nopad.py | 87 +------------------ .../test_context_flashattention_nopad.py | 6 +- 4 files changed, 12 insertions(+), 93 deletions(-) diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/int4kv_triton_backend.py index 1e8d88518b..f2421dd815 100644 --- a/lightllm/common/basemodel/attention/int4kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int4kv_triton_backend.py @@ -81,9 +81,9 @@ def _groupsize_quant_prefill_att( max_kv_seq_len = self.infer_state.max_kv_seq_len - from ..triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv + from ..triton_kernel.kv_copy.ppl_int4kv_copy_kv import dequantize_int4kv - dequantize_int8kv( + dequantize_int4kv( k=k, k_scale=k_scale, v=v, @@ -98,9 +98,9 @@ def _groupsize_quant_prefill_att( quant_group_size=self.backend.quant_group_size, ) - from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_ppl_int8kv + from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv - context_attention_fwd_ppl_int8kv( + context_attention_fwd_contiguous_kv( q=q, k=k_dequant, v=v_dequant, diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index bcf927ad5e..361b9b2b5c 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -98,9 +98,9 @@ def _groupsize_quant_prefill_att( quant_group_size=self.backend.quant_group_size, ) - from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_ppl_int8kv + from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv - context_attention_fwd_ppl_int8kv( + context_attention_fwd_contiguous_kv( q=q, k=k_dequant, v=v_dequant, diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 29a56603b9..5ba6d0beb6 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -336,7 +336,7 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma @triton.jit -def _fwd_kernel_int8kv( +def _fwd_kernel_contiguous_kv( Q, K, V, @@ -456,7 +456,7 @@ def _fwd_kernel_int8kv( @torch.no_grad() -def context_attention_fwd_ppl_int8kv( +def context_attention_fwd_contiguous_kv( q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, max_q_input_len, b_prompt_cache_len ): BLOCK_M = 128 if not is_tesla() else 64 @@ -476,7 +476,7 @@ def context_attention_fwd_ppl_int8kv( num_warps = 4 if Lk <= 64 else 8 num_stages = 1 - _fwd_kernel_int8kv[grid]( + _fwd_kernel_contiguous_kv[grid]( Q=q, K=k, V=v, @@ -598,86 +598,5 @@ def test(): assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) -def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_cache_len): - - batch = b_start_loc.shape[0] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - for i in range(batch): - start_loc = b_start_loc[i] - seq_len = b_seq_len[i] - prompt_cache_len = b_prompt_cache_len[i] - cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] - cur_q = cur_q.clone().to(torch.float32) - cur_k = k[i, :seq_len, :] - cur_k = cur_k.clone().to(torch.float32) - - cur_v = v[i, :seq_len, :] - cur_v = cur_v.clone().to(torch.float32) - - cur_q = cur_q.transpose(0, 1) - cur_k = cur_k.transpose(0, 1) - cur_v = cur_v.transpose(0, 1) - dk = cur_q.shape[-1] - - p = torch.matmul(cur_q, cur_k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) - - q_index = torch.arange(cur_q.shape[1]).unsqueeze(-1).to(p.device) - k_index = torch.arange(cur_k.shape[1]).unsqueeze(0).to(p.device) - mask = (q_index + prompt_cache_len >= k_index).int() - mask = mask.unsqueeze(0).expand(cur_q.shape[0], -1, -1) - - p = p.masked_fill(mask == 0, float("-inf")) - - s = F.softmax(p, dim=-1) - - o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) - - -def test2(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 16, 16, 2048, 128 - dtype = torch.float16 - prompt_cache_len = 0 - q = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - kv = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - k = kv[:, :H] - v = kv[:, H:] - # v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - torch_o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.3, std=0.2 - ) - max_input_len = N_CTX - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(Z, dtype=torch.int32, device="cuda") - - for i in range(Z): - b_seq_len[i] = N_CTX - if i != 0: - b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len - b_prompt_cache_len[i] = prompt_cache_len - torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_seq_len, b_prompt_cache_len) - - import time - - torch.cuda.synchronize() - a = time.time() - for i in range(1000): - context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len) - torch.cuda.synchronize() - b = time.time() - # print(o.shape, torch_out.shape) - print((b - a)) - - print("max ", torch.max(torch.abs(torch_o - o))) - print("mean ", torch.mean(torch.abs(torch_o - o))) - assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) - - if __name__ == "__main__": test() - test2() diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py index 5ab4be6b33..d1a53f873f 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py @@ -2,7 +2,7 @@ import time import pytest from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( - context_attention_fwd_ppl_int8kv, + context_attention_fwd_contiguous_kv, ) from lightllm.utils.log_utils import init_logger @@ -54,7 +54,7 @@ def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_ for prompt_cache_len in [0, 56, 200] ], ) -def test_context_attention_fwd_ppl_int8kv(B, H, N_CTX, D_HEAD, prompt_cache_len): +def test_context_attention_fwd_contiguous_kv(B, H, N_CTX, D_HEAD, prompt_cache_len): dtype = torch.float16 prompt_cache_len = 0 if prompt_cache_len >= N_CTX - 1: @@ -83,7 +83,7 @@ def test_context_attention_fwd_ppl_int8kv(B, H, N_CTX, D_HEAD, prompt_cache_len) b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len) - context_attention_fwd_ppl_int8kv( + context_attention_fwd_contiguous_kv( q=q, k=k, v=v, From 61e4f2447da25adf0d6ed79cd51477f41cb0b19b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 02:26:08 +0000 Subject: [PATCH 025/114] add int4kv backend --- .../attention/int4kv_triton_backend.py | 47 ++---- .../att/decode_att/int4kv/__init__.py | 0 .../int4kv/ppl_int4kv_flash_decoding.py | 50 +++++++ .../llama/triton_kernel/ppl_int4kv_copy_kv.py | 138 ------------------ 4 files changed, 59 insertions(+), 176 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py delete mode 100644 lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/int4kv_triton_backend.py index f2421dd815..f1c1491c52 100644 --- a/lightllm/common/basemodel/attention/int4kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int4kv_triton_backend.py @@ -134,47 +134,18 @@ def decode_att( assert att_control.use_alibi is False k, k_scale = k v, v_scale = v - if enable_diverse_mode_gqa_decode_fast_kernel(): - return self.diverse_decode_att( - q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, layer_weight=layer_weight, alloc_func=alloc_func - ) - else: - return self.ppl_mha_int8kv_decode_att( - q=q, - k=k, - k_scale=k_scale, - v=v, - v_scale=v_scale, - layer_weight=layer_weight, - alloc_func=alloc_func, - ) - - def diverse_decode_att( - self, - q: torch.Tensor, - k: torch.Tensor, - k_scale: torch.Tensor, - v: torch.Tensor, - v_scale: torch.Tensor, - layer_weight, - alloc_func=torch.empty, - ) -> torch.Tensor: - - from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( - token_decode_attention_flash_decoding, - ) - return token_decode_attention_flash_decoding( + return self.ppl_int4kv_decode_att( q=q, - infer_state=self.infer_state, - cache_k=k, - cache_k_scale=k_scale, - cache_v=v, - cache_v_scale=v_scale, - alloc_tensor_func=alloc_func, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + layer_weight=layer_weight, + alloc_func=alloc_func, ) - def ppl_mha_int8kv_decode_att( + def ppl_int4kv_decode_att( self, q: torch.Tensor, k: torch.Tensor, @@ -184,7 +155,7 @@ def ppl_mha_int8kv_decode_att( layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: - from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( + from ..triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py new file mode 100644 index 0000000000..8c61ed3c4e --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -0,0 +1,50 @@ +import torch +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops + + +def token_decode_attention_flash_decoding( + q, + infer_state, + cache_k, + cache_k_scale, + cache_v, + cache_v_scale, + out=None, + alloc_tensor_func=torch.empty, +): + BLOCK_SEQ = 256 + batch_size = infer_state.batch_size + max_len_in_batch = infer_state.max_len_in_batch + q_head_num = q.shape[1] + head_dim = q.shape[2] + calcu_shape1 = (batch_size, q_head_num, head_dim) + + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out + + mid_o = alloc_tensor_func( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + ) + mid_o_logexpsum = alloc_tensor_func( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + ) + + light_ops.group8_int4kv_flashdecoding_stage1( + BLOCK_SEQ, + mid_o, + mid_o_logexpsum, + 1.0 / (head_dim ** 0.5), + q.view(calcu_shape1), + cache_k, + cache_k_scale, + cache_v, + cache_v_scale, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + return o_tensor diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py deleted file mode 100644 index 7ba0f3b31b..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_int4_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data_0 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - src_data_1 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - - abs_data_0 = tl.abs(src_data_0) - abs_data_1 = tl.abs(src_data_1) - - data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty) - q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) - q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) - q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) - - q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) - q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) - q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) - - low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) - high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 - - # tl.device_print(low_4) - # tl.device_print(high_4) - - out_data = low_4 | high_4 - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_int4kv(K, DestLoc, Out, Out_scale): - # seq_len = DestLoc.shape[0] - # head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - # grid = (seq_len, head_num) - # num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view( - Out.shape[0], Out.shape[1], group_size, group_dim // 2 - ) # OUt 是 int8 类型, 两个int4组一个int8,所以 group_dim // 2 - - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_copy_int4_kv - - group8_copy_int4_kv(Out, Out_scale, K, DestLoc, 4) - - # _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( - # K, - # DestLoc, - # Out, - # Out_scale, - # K.stride(0), - # K.stride(1), - # K.stride(2), - # K.stride(3), - # Out.stride(0), - # Out.stride(1), - # Out.stride(2), - # Out.stride(3), - # Out_scale.stride(0), - # Out_scale.stride(1), - # Out_scale.stride(2), - # group_size, - # BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - # BLOCK_GROUP_DIM=group_dim, - # num_warps=num_warps, - # num_stages=1, - # ) - return - - -def test2(): - import time - - src = torch.randn((1, 1, 8), dtype=torch.float16).cuda() - src[0, 0, :] = torch.tensor([1, -2, 2, 0, 4, 5, 6, 7]).cuda() - dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda() - value_dest = torch.randn((1, 1, 4), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((1, 1, 1), dtype=torch.float16).cuda() - - destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest) - - print(value_dest) - print(scale_dest) - - -if __name__ == "__main__": - test2() From deba72765b787c7247d7565b694ce273a7753cce Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 02:29:16 +0000 Subject: [PATCH 026/114] fix --- .../att/decode_att/ppl_fp16/__init__.py | 0 .../ppl_fp16/ppl_fp16_flash_decoding.py} | 23 +++++----------- .../layer_infer/transformer_layer_infer.py | 26 ------------------- 3 files changed, 6 insertions(+), 43 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py rename lightllm/{models/llama/triton_kernel/ppl_int4kv_flash_decoding.py => common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py} (69%) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py similarity index 69% rename from lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py index 1e324bcc0b..fc21848e16 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py @@ -1,25 +1,16 @@ import torch +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -def token_decode_attention_flash_decoding( - q, - infer_state, - q_head_num, - head_dim, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=None, - alloc_tensor_func=torch.empty, -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size + q_head_num = q.shape[1] + head_dim = q.shape[2] max_len_in_batch = infer_state.max_len_in_batch calcu_shape1 = (batch_size, q_head_num, head_dim) - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out @@ -30,16 +21,14 @@ def token_decode_attention_flash_decoding( [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) - group8_int4kv_flashdecoding_stage1( + light_ops.fp16_flashdecoding_stage1( BLOCK_SEQ, mid_o, mid_o_logexpsum, 1.0 / (head_dim ** 0.5), q.view(calcu_shape1), cache_k, - cache_k_scale, cache_v, - cache_v_scale, infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 439cadd6bc..258736265b 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -606,32 +606,6 @@ def _token_decode_attention_ppl_fp16_flashdecoding( alloc_tensor_func=self.alloc_tensor, ) - def _token_decode_attention_ppl_int4kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ From 8ef04b42894794e91211047593d8034b5e580f9c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 04:52:13 +0000 Subject: [PATCH 027/114] add fa3 --- .../common/basemodel/attention/fa3_backend.py | 230 ++++++++++++++++++ .../llama/flashattention_infer_struct.py | 106 -------- lightllm/models/llama/infer_struct.py | 1 - .../layer_infer/transformer_layer_infer.py | 55 +---- .../triton_kernel/ppl_fp16_flash_decoding.py | 39 --- .../generic_padded_pre_process.py | 8 +- .../mode_backend/generic_pre_process.py | 3 +- 7 files changed, 239 insertions(+), 203 deletions(-) create mode 100644 lightllm/common/basemodel/attention/fa3_backend.py delete mode 100644 lightllm/models/llama/flashattention_infer_struct.py delete mode 100644 lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py new file mode 100644 index 0000000000..f0ee9b9b67 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -0,0 +1,230 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + +class Fa3AttBackend(BaseAttBackend): + def __init__(self, model: "TpPartBaseModel"): + super().__init__() + self.model = model + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. + """ + model = self.model + if self._shared_page_table_buffer is None: + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "Fa3PrefillAttState": + return Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": + return Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + self.page_table = torch.empty( + (self.infer_state.batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + self.page_table.copy_( + self.infer_state.req_manager.req_to_token_indexs[ + self.infer_state.b_req_idx, : self.infer_state.max_kv_seq_len + ] + ) + + def copy_for_prefill_cuda_graph(self, new_state: "Fa3PrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fa3AttBackend = self.backend # for typing + + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k, + v_cache=v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + + +@dataclasses.dataclass +class Fa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: Fa3AttBackend = self.backend + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + args_mtp_step = get_env_start_args().mtp_step + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer(model.graph_max_batch_size, model.graph_max_len_in_batch) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): + pass + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._normal_decode_gqa_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _normal_decode_gqa_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k, + v_cache=v, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py deleted file mode 100644 index 9f71cbbc56..0000000000 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index -from lightllm.common.basemodel.batch_objs import ModelInput -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class FlashAttentionStateInfo(LlamaInferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def _init_flash_attention_state(self, model): - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=self.input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) - else: - # Meta information of flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - max_seq_len_k = self.max_kv_seq_len - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].reshape(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty( - (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=self.input_ids.device - ) - - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - - if "offline_calibration_fp8kv" in model.mode: - if self.is_prefill: - device = self.input_ids.device - # q_scale和token_batch_ids在对q做per head量化使用,为了节省资源在推理外部初始化 - self.q_scale = torch.empty( - (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device - ) - self.token_batch_ids = torch.repeat_interleave( - torch.arange(self.batch_size, device=device), self.b_q_seq_len - ) - - offline_scales = self.mem_manager.scales - head_num = self.mem_manager.head_num - # 为了减少推理计算量,在推理外部初始化k_descale和v_descale - self.k_descale = ( - offline_scales[:, :head_num] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - self.v_descale = ( - offline_scales[:, head_num:] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - return - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self._init_flash_attention_state(model) - return diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index 3bba439767..fe6ca392a2 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -14,7 +14,6 @@ def init_some_extra_state(self, model): super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len position_ids = self.position_ids self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 258736265b..ff96127a67 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -97,26 +97,9 @@ def _bind_attention(self): else: # self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) pass - if "ppl_int8kv" in self.mode: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, - # self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - # ) - elif "ppl_int4kv_flashdecoding" in self.mode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self) - elif "ppl_fp16" in self.mode: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "ppl_fp16_flashdecoding" in self.mode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + + if "int8kv" in self.mode: + pass elif "offline_calibration_fp8kv" in self.mode: assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) @@ -269,38 +252,6 @@ def _context_attention_flashinfer_kernel( ) return o_tensor - def _context_attention_kernel_ppl_int8kv( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - batch_size = infer_state.b_seq_len.shape[0] - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = infer_state.mem_manager.scale_buffer[self.layer_num_] - max_seq_len = infer_state.max_seq_len - kv_dequant = self.alloc_tensor( - (batch_size, kv.shape[1], max_seq_len, kv.shape[2]), device=q.device, dtype=q.dtype - ) - destindex_copy_dequantize_kv( - kv, - kv_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_seq_len, - infer_state.b_req_idx, - max_seq_len, - kv_dequant, - ) - # context_attention_fwd_ppl_int8kv( - # q.view(-1, self.tp_q_head_num_, self.head_dim_), - # kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - # kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - # o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - # infer_state.b_start_loc, - # infer_state.b_seq_len, - # infer_state.max_len_in_batch, - # infer_state.b_ready_cache_len, - # ) - return o_tensor - def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ diff --git a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py deleted file mode 100644 index 8fda084605..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from lightllm_ppl_fp16_flashdecoding_kernel import fp16_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" - ) - - fp16_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 6465995c45..7845b9b2ca 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -150,6 +150,7 @@ def padded_prepare_decode_inputs( seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1 b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len b_mtp_index.append(0) batch_multimodal_params.append(req.multimodal_params) @@ -160,29 +161,28 @@ def padded_prepare_decode_inputs( total_token_num += seq_len b_req_idx.append(req.req_idx) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(step + 1) batch_multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) - # padding fake req for decode for _ in range(padded_req_num): seq_len = 2 total_token_num += seq_len b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(0) batch_multimodal_params.append({"images": [], "audios": []}) for step in range(args_mtp_step): seq_len += 1 total_token_num += seq_len b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_mtp_index.append(step + 1) batch_multimodal_params.append({"images": [], "audios": []}) - b_q_seq_len.append(1 + args_mtp_step) - max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) max_len_in_batch = max(b_seq_len) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index bdb36054b4..963394116a 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -107,6 +107,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) @@ -121,7 +122,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) + b_q_seq_len.append(1) max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) From eb5b94a2209d7cf6e2078846e9671b5ac8f9a054 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 05:05:36 +0000 Subject: [PATCH 028/114] add fa3 --- .../common/basemodel/attention/fa3_backend.py | 18 +-- .../layer_infer/transformer_layer_infer.py | 110 +----------------- 2 files changed, 5 insertions(+), 123 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py index f0ee9b9b67..774793cca1 100644 --- a/lightllm/common/basemodel/attention/fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -15,23 +15,7 @@ class Fa3AttBackend(BaseAttBackend): def __init__(self, model: "TpPartBaseModel"): super().__init__() self.model = model - tp_world_size = get_dp_world_size() - self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size - self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) - head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] - self.head_dim = model.config.get("head_dim", head_dim) - self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - self.q_data_type = model.data_type - self.kv_data_type = model.data_type + self.get_page_table_buffer() # init def get_page_table_buffer(self): """ diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index ff96127a67..4215b75ff4 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -11,12 +11,10 @@ from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel import TransformerLayerInferTpl -from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -61,24 +59,8 @@ def _bind_norm(self): def _bind_attention(self): if get_env_start_args().enable_fa3: - if "offline_calibration_fp8kv" in self.mode: - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self - # ) - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - elif "export_fp8kv_calibration" in self.mode: - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_flashattention, self - # ) - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - # ) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) + if True: + pass elif not self.mode: # self._context_attention_kernel = partial( # LlamaTransformerLayerInfer._context_attention_flashattention, self @@ -252,39 +234,7 @@ def _context_attention_flashinfer_kernel( ) return o_tensor - def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _context_attention_flashattention_fp8( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): + def _context_attention_flashattention_fp8(self, q, kv, infer_state, layer_weight, out=None): q, q_scale = q_per_head_fp8_quant( q.view(q.shape[0], self.tp_k_head_num_, -1), infer_state.b_seq_len, @@ -537,59 +487,7 @@ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, return o_tensor - def _token_decode_attention_ppl_fp16_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_fp16_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_att_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _token_decode_attention_flashattention_fp8( - self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): + def _token_decode_attention_flashattention_fp8(self, q, infer_state, layer_weight, out=None): cache_k = ( (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) From abdc9281eceb1cd99350fd14e53d9cb198ca0cca Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 06:24:55 +0000 Subject: [PATCH 029/114] fix --- .../common/basemodel/attention/fa3_backend.py | 28 +- .../basemodel/attention/fp8_fa3_backend.py | 314 ++++++++++++++++++ .../triton_kernel/gen_decode_params.py | 6 +- .../layer_infer/transformer_layer_infer.py | 151 --------- 4 files changed, 339 insertions(+), 160 deletions(-) create mode 100644 lightllm/common/basemodel/attention/fp8_fa3_backend.py diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py index 774793cca1..d833a97ec8 100644 --- a/lightllm/common/basemodel/attention/fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -6,6 +6,7 @@ from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor if TYPE_CHECKING: from lightllm.common.basemodel.basemodel import TpPartBaseModel @@ -120,10 +121,27 @@ class Fa3DecodeAttState(BaseDecodeAttState): def init_state(self): self.backend: Fa3AttBackend = self.backend - self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( + b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] + ) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 @@ -175,7 +193,7 @@ def decode_att( alloc_func=torch.empty, ): assert att_control.use_alibi is False - return self._normal_decode_gqa_att( + return self._normal_decode_att( q=q, k=k, v=v, @@ -183,7 +201,7 @@ def decode_att( alloc_func=alloc_func, ) - def _normal_decode_gqa_att( + def _normal_decode_att( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fp8_fa3_backend.py new file mode 100644 index 0000000000..f6e22e10ec --- /dev/null +++ b/lightllm/common/basemodel/attention/fp8_fa3_backend.py @@ -0,0 +1,314 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant +else: + scaled_fp8_quant = None + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + +class Fp8Fa3AttBackend(BaseAttBackend): + def __init__(self, model: "TpPartBaseModel"): + super().__init__() + self.model = model + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. + """ + model = self.model + if self._shared_page_table_buffer is None: + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": + return Fp8Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8Fa3DecodeAttState": + return Fp8Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8Fa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + # 临时共享变量 + mid_token_batch_ids: torch.Tensor = None + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + self.page_table = torch.empty( + (self.infer_state.batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + self.page_table.copy_( + self.infer_state.req_manager.req_to_token_indexs[ + self.infer_state.b_req_idx, : self.infer_state.max_kv_seq_len + ] + ) + + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + self.mid_token_batch_ids = torch.repeat_interleave( + torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len + ) + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + + def copy_for_prefill_cuda_graph(self, new_state: "Fp8Fa3PrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fp8Fa3AttBackend = self.backend # for typing + + q, q_scale = q_per_head_fp8_quant( + q, + self.infer_state.b_seq_len, + self.cu_seqlens_q, + self.mid_token_batch_ids, + ) + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + causal=True, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale, + k_descale=self.k_descale[layer_weight.layer_num_], + v_descale=self.v_descale[layer_weight.layer_num_], + return_softmax_lse=False, + ) + return o + + +@dataclasses.dataclass +class Fp8Fa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + self.backend: Fp8Fa3AttBackend = self.backend + + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( + b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] + ) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer(model.graph_max_batch_size, model.graph_max_len_in_batch) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + device = self.infer_state.input_ids.device + batch_size = att_batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"): + pass + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._fp8_decode_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + + q_head_num = q.shape[1] + q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True) + o = flash_attn_with_kvcache( + q=q.view(-1, q_head_num, k_head_dim), + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + causal=False, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), + k_descale=self.k_descale[layer_weight.layer_num_], + v_descale=self.v_descale[layer_weight.layer_num_], + return_softmax_lse=False, + ) + return o diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py index 9804e46681..c8a6a850bc 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py @@ -9,8 +9,6 @@ def gen_decode_params(b_seq_len: torch.Tensor): b_kv_seq_len = b_seq_len position_ids = b_seq_len - 1 - mtp_step = get_env_start_args().mtp_step - mtp_size = mtp_step + 1 - b_q_seq_len = torch.ones(b_seq_len.shape[0] // mtp_size, dtype=torch.int32, device=b_seq_len.device) * mtp_size - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size]) + b_q_seq_len = torch.ones(b_seq_len.shape[0], dtype=torch.int32, device=b_seq_len.device) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 4215b75ff4..e6efb7e26d 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -1,15 +1,11 @@ import torch import triton -import torch.functional as F import torch.distributed as dist -import numpy as np -from typing import Tuple from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd - from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv @@ -18,19 +14,9 @@ from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None logger = init_logger(__name__) -from lightllm.utils.sgl_utils import flash_attn_with_kvcache - class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ @@ -61,23 +47,12 @@ def _bind_attention(self): if get_env_start_args().enable_fa3: if True: pass - elif not self.mode: - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_flashattention, self - # ) - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: self._context_attention_kernel = partial( LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self ) else: - # self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) pass if "int8kv" in self.mode: @@ -234,47 +209,6 @@ def _context_attention_flashinfer_kernel( ) return o_tensor - def _context_attention_flashattention_fp8(self, q, kv, infer_state, layer_weight, out=None): - q, q_scale = q_per_head_fp8_quant( - q.view(q.shape[0], self.tp_k_head_num_, -1), - infer_state.b_seq_len, - infer_state.cu_seqlens_q, - infer_state.q_scale, - infer_state.token_batch_ids, - ) - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale, - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -437,91 +371,6 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor - def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, - # at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - light_ops.group8_int8kv_decode_attention( - o_tensor.view(calcu_shape1), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm_ppl_fp16_kernel import fp16_decode_attention - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, - # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - fp16_decode_attention( - o_tensor.view(calcu_shape1), - 1.0 / (self.head_dim_ ** 0.5), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_flashattention_fp8(self, q, infer_state, layer_weight, out=None): - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_), - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, From a2be571c93a14eb57c11b8828760bc2f14ae4efb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:09:20 +0000 Subject: [PATCH 030/114] fix --- .../basemodel/attention/flashinfer_backend.py | 236 ++++++++++++++++++ .../triton_kernel/repack_kv_index.py | 90 +++++++ lightllm/models/llama/model.py | 27 +- 3 files changed, 327 insertions(+), 26 deletions(-) create mode 100644 lightllm/common/basemodel/attention/flashinfer_backend.py create mode 100644 lightllm/common/basemodel/triton_kernel/repack_kv_index.py diff --git a/lightllm/common/basemodel/attention/flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer_backend.py new file mode 100644 index 0000000000..487634b2e8 --- /dev/null +++ b/lightllm/common/basemodel/attention/flashinfer_backend.py @@ -0,0 +1,236 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ..triton_kernel.repack_kv_index import repack_kv_index + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + +class FlashInferAttBackend(BaseAttBackend): + def __init__(self, model: "TpPartBaseModel"): + super().__init__() + self.model = model + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state) -> "FlashInferPrefillAttState": + return FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "FlashInferDecodeAttState": + return FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class FlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: FlashInferAttBackend = self.backend + + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device=device) + kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + kv_starts[:-1], + self.infer_state.max_kv_seq_len, + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def copy_for_prefill_cuda_graph(self, new_state: "FlashInferPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: FlashInferAttBackend = self.backend # for typing + + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q.view(q.shape[0], -1, self.backend.head_dim), + k.unsqueeze(1)[:, :, :, :], + v.unsqueeze(1)[:, :, :, :], + out=o_tensor.view(q.shape[0], -1, self.backend.head_dim), + ) + return o_tensor + + +@dataclasses.dataclass +class FlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: FlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + self.kv_last_page_len_buffer = torch.full((self.infer_state.batch_size,), 1, dtype=torch.int32, device=device) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + self.infer_state.batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + assert self.decode_wrapper is None + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + 1, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._normal_decode_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + k.unsqueeze(1)[:, :, :, :], + v.unsqueeze(1)[:, :, :, :], + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py new file mode 100644 index 0000000000..e86d2e819e --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -0,0 +1,90 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_repack_kv_index( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + kv_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + kv_index_data = tl.load( + kv_index + kv_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, kv_index_data, mask=offs_seq < block_end_loc) + return + + +@torch.no_grad() +def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): + batch_size = req_index.shape[0] + # flashinfer requires out_kv_index to be zeroed before use + out_kv_index.zero_() + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_kv_index[grid]( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + kv_index.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + +def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + output[start : start + sl] = req_to_token_indexs[b][:sl] + + +if __name__ == "__main__": + import torch.nn.functional as F + + BATCH, MAX_SEQ_LEN = 10, 1024 + rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() + b_req_idx = torch.randperm(BATCH).cuda().int() + b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() + req_to_token_indexs = torch.zeros((2 * BATCH, 2 * MAX_SEQ_LEN)).cuda().int() + b_start_loc = ( + torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[0:-1].cumsum(0)]) + .cuda() + .int() + ) + + output = torch.zeros((b_seq_len.sum(),)).cuda().int() + ref = torch.zeros((b_seq_len.sum(),)).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] + + fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) + fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) + ms1 = triton.testing.do_bench(fn1) + ms2 = triton.testing.do_bench_cudagraph(fn2) + print(ms1, ms2) + assert torch.allclose(output.float(), ref.float()) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 95465a9e6c..033816b1e5 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -21,27 +21,6 @@ logger = init_logger(__name__) -class LlamaFlashInferStateExtraInfo: - def __init__(self, model): - tp_world_size = get_dp_world_size() - self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size - self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) - head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] - self.head_dim = model.config.get("head_dim", head_dim) - self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - self.q_data_type = model.data_type - self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type - - @ModelRegistry("llama") class LlamaTpPartModel(TpPartBaseModel): # weight class @@ -95,11 +74,7 @@ def _init_mem_manager(self): return def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = LlamaFlashInferStateInfo - self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self) + pass def _init_custom(self): """ From ae5bb45d368156360316e577f25ac96994fe4f80 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:11:52 +0000 Subject: [PATCH 031/114] fix --- .../layer_infer/transformer_layer_infer.py | 26 -- .../llama/triton_kernel/ppl_quant_copy_kv.py | 294 ------------------ 2 files changed, 320 deletions(-) delete mode 100644 lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index e6efb7e26d..fcde917be0 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -196,19 +196,6 @@ def _context_attention_flashinfer_kernel_fp8( ) return o_tensor - def _context_attention_flashinfer_kernel( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -358,19 +345,6 @@ def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInfer ) return o_tensor - def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py deleted file mode 100644 index 3d9a490f47..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], - mask=offs_g[:, None] < group_size, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) - q_src_data = (src_data / data_scale[:, None]).to(tl.int8) - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (seq_len, head_num) - num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - K.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - group_size, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_dequantize_kv( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - stride_kv_b, - stride_kv_h, - stride_kv_g, - stride_kv_d, - stride_o_bh, - stride_o_l, - stride_o_g, - stride_o_d, - stride_s_b, - stride_s_h, - stride_s_g, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - group_size, - head_num: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_group = tl.program_id(0) - start_m = tl.program_id(1) - cur_bh = tl.program_id(2) - cur_batch = cur_bh // head_num - cur_head = cur_bh % head_num - - block_start_loc = BLOCK_SIZE * start_m - cur_batch_req_idx = tl.load(b_req_idx + cur_batch) - cur_seq_len = tl.load(b_seq_len + cur_batch) - - # initialize offsets - offs_kv_loc = block_start_loc + tl.arange(0, BLOCK_SIZE) - - # offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - kv_loc = tl.load( - req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len - ).to(tl.int64) - offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :] - - src_data = tl.load( - mem_kv_buffer + offs_kv, - mask=offs_kv_loc[:, None] < cur_seq_len, - other=0.0, - ).to(Out.dtype.element_ty) - - s_ptrs = mem_kv_scale + kv_loc * stride_s_b + cur_head * stride_s_h + cur_group * stride_s_g - data_scale = tl.load( - s_ptrs, - mask=offs_kv_loc < cur_seq_len, - ) - - out_data = src_data * data_scale[:, None] - o_ptrs = Out + cur_bh * stride_o_bh + offs_kv_loc[:, None] * stride_o_l + cur_group * stride_o_g + offs_d[None, :] - tl.store(o_ptrs, out_data, mask=offs_kv_loc[:, None] < cur_seq_len) - return - - -@torch.no_grad() -def destindex_copy_dequantize_kv( - mem_kv_buffer, mem_kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_len_in_batch, Out -): - batch_size = b_seq_len.shape[0] - head_num = mem_kv_buffer.shape[1] - head_dim = mem_kv_buffer.shape[2] - quant_group_dim = 8 - BLOCK_SIZE = 128 - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (group_size, triton.cdiv(max_len_in_batch, BLOCK_SIZE), batch_size * head_num) - num_warps = 1 - mem_kv_buffer = mem_kv_buffer.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], group_size, group_dim)) - mem_kv_scale = mem_kv_scale.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], -1)) - Out = Out.view(Out.shape[0] * Out.shape[1], -1, group_size, group_dim) - - _fwd_kernel_destindex_copy_dequantize_kv[grid]( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - mem_kv_buffer.stride(0), - mem_kv_buffer.stride(1), - mem_kv_buffer.stride(2), - mem_kv_buffer.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - mem_kv_scale.stride(0), - mem_kv_scale.stride(1), - mem_kv_scale.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - group_size, - head_num=head_num, - BLOCK_SIZE=BLOCK_SIZE, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test2(): - import time - - B, N_CTX, H, D = 1, 3, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - value_dest = value_dest.view((B * N_CTX, H, D // 8, 8)) - scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1)) - print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -def torch_dequant(kv, kv_scale, o, b_req_idx, b_seq_len, req_to_token_indexs): - - batch = b_req_idx.shape[0] - for i in range(batch): - req_idx = b_req_idx[i] - seq_len = b_seq_len[i] - print(seq_len, b_seq_len) - kv_loc = req_to_token_indexs[req_idx, :seq_len] - head_num = kv.shape[1] - cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, 8).to(o.dtype) - cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) - out = cur_kv * cur_scale - o[i, :seq_len, :, :] = out.reshape(out.shape[0], out.shape[1], -1) - - -def test3(): - import time - import numpy as np - - Z, H, N_CTX, D_HEAD = 1, 16, 3, 128 - dtype = torch.bfloat16 - kv = torch.empty((Z * N_CTX + 100, 2 * H, D_HEAD), dtype=torch.int8, device="cuda") - kv_scale = torch.randn((Z * N_CTX + 100, 2 * H, D_HEAD // 8), dtype=dtype, device="cuda") - out = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - torch_out = torch.empty((Z, N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.empty((1000, N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - for i in range(Z): - seq_len = N_CTX - i * 100 - b_seq_len[i] = seq_len - b_req_idx[i] = i - req_to_token_indexs[i][:seq_len] = ( - torch.tensor(np.arange(seq_len), dtype=torch.int32).cuda() + b_seq_len[0:i].sum() - ) - print(b_seq_len) - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch_dequant(kv, kv_scale, torch_out, b_req_idx, b_seq_len, req_to_token_indexs) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch.cuda.synchronize() - t2 = time.time() - print((t2 - t1)) - torch_out = torch_out.transpose(1, 2) - for i in range(Z): - print("max ", torch.max(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - print("mean ", torch.mean(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - assert torch.allclose(torch_out[i][:, : b_seq_len[i]], out[i][:, : b_seq_len[i]], atol=1e-2, rtol=0) - # print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # cos = torch.nn.CosineSimilarity(0) - # print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test3() From c5b0c0a0efb99a214ba0747e8c2d7a6ba3fdc3e0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:17:47 +0000 Subject: [PATCH 032/114] fix --- lightllm/common/kv_cache_mem_manager/mem_manager.py | 9 +++++++++ .../llama/layer_infer/transformer_layer_infer.py | 11 ----------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009f..554327bb33 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -5,6 +5,7 @@ import torch.multiprocessing as mp from typing import List, Union from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp +from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -21,6 +22,7 @@ from multiprocessing.reduction import ForkingPickler from filelock import FileLock + logger = init_logger(__name__) @@ -65,6 +67,13 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False ) self.HOLD_TOKEN_MEMINDEX = self.size + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) + return + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index fcde917be0..cdf7beb8df 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -44,17 +44,6 @@ def _bind_norm(self): return def _bind_attention(self): - if get_env_start_args().enable_fa3: - if True: - pass - return - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) - else: - pass - if "int8kv" in self.mode: pass elif "offline_calibration_fp8kv" in self.mode: From bd825dc248abf140a140d7a72c779fe78cb4bea4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:27:52 +0000 Subject: [PATCH 033/114] fix --- .../triton_kernel/destindex_copy_kv.py | 124 ------------------ .../kv_copy/ppl_int8kv_copy_kv.py | 4 +- .../ppl_int8kv_mem_manager.py | 14 ++ .../layer_infer/transformer_layer_infer.py | 22 ---- 4 files changed, 16 insertions(+), 148 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py index bd53de386e..be29f92667 100644 --- a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py @@ -60,127 +60,3 @@ def destindex_copy_kv(K, DestLoc, Out): num_stages=1, ) return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=offs_h[:, None] < head_num, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num) - tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - BLOCK_DMODEL=head_dim, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py index f1d85bed84..e5ee5cb8b8 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py @@ -49,11 +49,11 @@ def _fwd_kernel_destindex_copy_quantize_kv( @torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): +def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale, quant_group_dim): seq_len = DestLoc.shape[0] head_num = K.shape[1] head_dim = K.shape[2] - quant_group_dim = 8 + assert triton.next_power_of_2(quant_group_dim) == quant_group_dim, "error quant group dim" assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" grid = (seq_len, head_num) diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index 2a5aad7c8b..f1d7a23543 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -1,6 +1,7 @@ import torch from .mem_manager import MemoryManager +from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv class PPLINT8KVMemoryManager(MemoryManager): @@ -9,6 +10,19 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + destindex_copy_quantize_kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_dim=self.group_quant_size, + ) + return + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index cdf7beb8df..4911c44fc8 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -282,12 +282,6 @@ def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager) mem_manager.update_calibration_data(buffer, self.layer_num_) return - def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): scales = mem_manager.scales destindex_copy_kv_fp8( @@ -298,22 +292,6 @@ def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): ) return - def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv - - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv - - destindex_copy_int4kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): batch_size = infer_state.batch_size calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) From b7ce3f39a3eaf3dde6c0db9ed86e533d437c59f8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:29:54 +0000 Subject: [PATCH 034/114] fix memmanager --- .../kv_cache_mem_manager/ppl_int4kv_mem_manager.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index f3218594d3..95821ba786 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -1,6 +1,7 @@ import torch from .mem_manager import MemoryManager +from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv class PPLINT4KVMemoryManager(MemoryManager): @@ -9,6 +10,19 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + destindex_copy_int4kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_size=self.group_quant_size, + ) + return + def get_cell_size(self): return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size( self.kv_dtype From 15afeb55be37454ccfe46e40124d6f61465c0bf3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 07:34:55 +0000 Subject: [PATCH 035/114] fix memmanager --- .../common/kv_cache_mem_manager/__init__.py | 2 -- .../int8kv_mem_manager.py | 29 ------------------- .../common/kv_cache_mem_manager/mem_utils.py | 6 +--- .../layer_infer/transformer_layer_infer.py | 9 ------ lightllm/utils/kv_cache_utils.py | 1 - 5 files changed, 1 insertion(+), 46 deletions(-) delete mode 100755 lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 66caf5d789..e2ddac45a0 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -1,5 +1,4 @@ from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager -from .int8kv_mem_manager import INT8KVMemoryManager from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager from .export_calibration_mem_manager import ExportCalibrationMemoryManager from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager @@ -10,7 +9,6 @@ __all__ = [ "MemoryManager", "ReadOnlyStaticsMemoryManager", - "INT8KVMemoryManager", "CalibrationFP8KVMemoryManager", "ExportCalibrationMemoryManager", "PPLINT4KVMemoryManager", diff --git a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py deleted file mode 100755 index 5725cdb7bb..0000000000 --- a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch - -from .mem_manager import MemoryManager - - -class INT8KVMemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9): - self.kv_dtype = torch.int8 - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=mem_fraction) - - def get_cell_size(self): - return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( - self.kv_dtype - ) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") - self.scale_buffer = torch.empty((layer_num, size + 1, 2 * head_num, 1), dtype=dtype, device="cuda") - - def _free_buffers(self): - self.kv_buffer = None - self.scale_buffer = None - - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"]) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 259c5a56f8..23e24d9a42 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -1,6 +1,5 @@ from . import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, @@ -41,9 +40,6 @@ def select_mem_manager_class(): elif "ppl_int4kv_flashdecoding" in mode: memory_manager_class = PPLINT4KVMemoryManager logger.info(f"Model kv cache using mode {mode}") - elif "triton_int8kv" in mode: - memory_manager_class = INT8KVMemoryManager - logger.info("Model kv cache using mode triton int8kv") elif "triton_fp8kv" in mode: raise Exception("currently only for deepseek") elif "offline_calibration_fp8kv" in mode: @@ -61,4 +57,4 @@ def select_mem_manager_class(): @lru_cache(maxsize=None) def used_mem_manager_has_scale() -> bool: mem_class = select_mem_manager_class() - return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, INT8KVMemoryManager] + return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager] diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 4911c44fc8..38b45bf585 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -273,15 +273,6 @@ def _tpsp_ffn( # gate_out, up_out = None, None # return ffn2_out - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - mem_manager.update_calibration_data(buffer, self.layer_num_) - return - def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): scales = mem_manager.scales destindex_copy_kv_fp8( diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 4875b4eee6..3de806330c 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -19,7 +19,6 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.kv_cache_mem_manager import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, From c323eadad66231f452af305c0f11b4a7e4bd050f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:00:01 +0000 Subject: [PATCH 036/114] fix all --- .../triton_kernel/kv_copy/mla_copy_kv.py | 107 ++++++++++++++++++ .../common/kv_cache_mem_manager/__init__.py | 2 - .../deepseek2_fp8kv_mem_manager.py | 8 -- .../deepseek2_mem_manager.py | 17 +++ .../export_calibration_mem_manager.py | 15 +++ .../common/kv_cache_mem_manager/mem_utils.py | 4 - lightllm/utils/kv_cache_utils.py | 12 -- 7 files changed, 139 insertions(+), 26 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py delete mode 100644 lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py new file mode 100644 index 0000000000..39deb1b6f7 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py @@ -0,0 +1,107 @@ +import torch + +import triton +import triton.language as tl + + +def _is_power_of_two(n): + return n > 0 and (n & (n - 1)) == 0 + + +@triton.jit +def _fwd_kernel_destindex_copy_kv( + KV_nope, + KV_rope, + Dest_loc, + O_nope, + O_rope, + stride_kv_nope_bs, + stride_kv_nope_h, + stride_kv_nope_d, + stride_kv_rope_bs, + stride_kv_rope_h, + stride_kv_rope_d, + stride_o_nope_bs, + stride_o_nope_h, + stride_o_nope_d, + stride_o_rope_bs, + stride_o_rope_h, + stride_o_rope_d, + BLOCK_DMODEL_NOPE: tl.constexpr, + BLOCK_DMODEL_ROPE: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) + offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] + + o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] + + kv_nope = tl.load(kv_nope_ptrs) + kv_rope = tl.load(kv_rope_ptrs) + + tl.store(o_nope_ptrs, kv_nope) + tl.store(o_rope_ptrs, kv_rope) + return + + +@torch.no_grad() +def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): + seq_len = DestLoc.shape[0] + kv_nope_head_dim = KV_nope.shape[2] + kv_rope_head_dim = KV_rope.shape[2] + + assert KV_nope.shape[1] == O_nope.shape[1] + assert KV_nope.shape[2] == O_nope.shape[2] + assert KV_rope.shape[1] == O_rope.shape[1] + assert KV_rope.shape[2] == O_rope.shape[2] + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_kv[grid]( + KV_nope, + KV_rope, + DestLoc, + O_nope, + O_rope, + KV_nope.stride(0), + KV_nope.stride(1), + KV_nope.stride(2), + KV_rope.stride(0), + KV_rope.stride(1), + KV_rope.stride(2), + O_nope.stride(0), + O_nope.stride(1), + O_nope.stride(2), + O_rope.stride(0), + O_rope.stride(1), + O_rope.stride(2), + BLOCK_DMODEL_NOPE=kv_nope_head_dim, + BLOCK_DMODEL_ROPE=kv_rope_head_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +if __name__ == "__main__": + import torch.nn.functional as F + + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD = 32, 1024, 1, 512, 64 + dtype = torch.bfloat16 + dest_loc = torch.randint(0, 100, (50,), device="cuda").unique() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + O_nope = torch.zeros((B * N_CTX, H, NOPE_HEAD), dtype=dtype).cuda() + O_rope = torch.zeros((B * N_CTX, H, ROPE_HEAD), dtype=dtype).cuda() + + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + destindex_copy_kv(kv_nope, kv_rope, dest_loc, O_nope, O_rope) + + assert torch.allclose(O_nope[dest_loc], kv_nope, atol=1e-2, rtol=0) + assert torch.allclose(O_rope[dest_loc], kv_rope, atol=1e-2, rtol=0) diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index e2ddac45a0..7d516e6728 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,7 +4,6 @@ from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager -from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager __all__ = [ "MemoryManager", @@ -14,5 +13,4 @@ "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", - "Deepseek2FP8KVMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py deleted file mode 100644 index 00699f4b15..0000000000 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from .deepseek2_mem_manager import Deepseek2MemoryManager - - -class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - # scale被追加到kv_buffer末尾, 因此加2, dtype统一改成uint8 - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 7711734601..3a8310a71a 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -9,6 +9,7 @@ from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io +from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv logger = init_logger(__name__) @@ -17,6 +18,22 @@ class Deepseek2MemoryManager(MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + self.kv_buffer[layer_index][:, :, :kv_lora_rank], + self.kv_buffer[layer_index][:, :, kv_lora_rank:], + ) + return + def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py index b2749176ea..a51691d226 100755 --- a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py @@ -1,6 +1,21 @@ +import torch from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager +from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=True) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + scales = self.scales + destindex_copy_kv_fp8( + kv, + mem_index, + scales[layer_index] if scales is not None else None, + self.kv_buffer[layer_index].view(torch.float8_e4m3fn), + ) + return diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 23e24d9a42..ea8d9d2fb7 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -5,7 +5,6 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -26,9 +25,6 @@ def select_mem_manager_class(): if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager - if "triton_fp8kv" in mode: - mem_class = Deepseek2FP8KVMemoryManager - logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") return mem_class diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 3de806330c..3256fdd1fd 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -24,7 +24,6 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from typing import List, Tuple, Optional @@ -76,17 +75,6 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": scale_head_dim=0, scale_data_type=get_llm_data_type(), ) - elif mem_manager_class is Deepseek2FP8KVMemoryManager: - cpu_cache_meta = CpuKVCacheMeta( - page_num=0, - token_page_size=args.cpu_cache_token_page_size, - layer_num=get_layer_num(args.model_dir), - num_heads=1, - head_dim=512 + 64 + 2, - data_type=torch.uint8, - scale_head_dim=0, - scale_data_type=get_llm_data_type(), - ) elif mem_manager_class is MemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, From 8ca694a15ffbb511780010187a699629b07925f9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:21:45 +0000 Subject: [PATCH 037/114] add fp8 flashattention --- .../attention/fp8_flashinfer_backend.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 lightllm/common/basemodel/attention/fp8_flashinfer_backend.py diff --git a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py new file mode 100644 index 0000000000..abed5f8b03 --- /dev/null +++ b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py @@ -0,0 +1,118 @@ +import dataclasses +import torch +from .base_att import AttControl +from typing import Optional, TYPE_CHECKING +from ..triton_kernel.repack_kv_index import repack_kv_index +from .flashinfer_backend import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + + +class Fp8FlashInferAttBackend(FlashInferAttBackend): + def __init__(self, model: "TpPartBaseModel"): + super().__init__(model=model) + self.kv_data_type = torch.float8_e4m3fn + + def create_att_prefill_state(self, infer_state) -> "Fp8FlashInferPrefillAttState": + return Fp8FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState": + return Fp8FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + ) -> torch.Tensor: + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + + offline_scales = self.offline_scales + k_descale = offline_scales[layer_weight.layer_num_][0] if offline_scales is not None else None + v_descale = offline_scales[layer_weight.layer_num_][1] if offline_scales is not None else None + self.prefill_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._fp8_decode_att( + q=q, + k=k, + v=v, + layer_weight=layer_weight, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + offline_scales = self.offline_scales + k_descale = offline_scales[layer_weight.layer_num_][0] if offline_scales is not None else None + v_descale = offline_scales[layer_weight.layer_num_][1] if offline_scales is not None else None + self.decode_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor From f1c5e4a1b606cdaa0d31c6fe50131b65c23fc68c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:29:18 +0000 Subject: [PATCH 038/114] fix llama --- lightllm/models/llama/flashinfer_struct.py | 127 --------------- .../layer_infer/transformer_layer_infer.py | 150 ++++-------------- lightllm/models/llama/model.py | 3 - 3 files changed, 34 insertions(+), 246 deletions(-) delete mode 100644 lightllm/models/llama/flashinfer_struct.py diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py deleted file mode 100644 index 7f9beac1db..0000000000 --- a/lightllm/models/llama/flashinfer_struct.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class LlamaFlashInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device - ) - if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - self.kv_starts = self.b1_cu_kv_seq_len.int() - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=self.kv_starts, - paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, - ) - self.decode_wrapper.plan( - self.kv_starts, - self.kv_indices, - self.kv_last_page_len_buffer, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_cu_kv_seq_len.int() - kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device) - kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - qo_indptr_buf=q_starts, - paged_kv_indptr_buf=kv_starts, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len, - ) - self.prefill_wrapper.plan( - q_starts, - kv_starts, - kv_indices, - kv_last_page_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.kv_starts, - new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.tp_kv_head_num, - new_infer_state.flashinfer_extra_state.head_dim, - 1, - q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, - kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - return diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 38b45bf585..17c4a19e2d 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -35,7 +35,6 @@ def __init__(self, layer_num, network_config, mode=[]): def _bind_func(self): self._bind_norm() - self._bind_attention() return def _bind_norm(self): @@ -43,38 +42,41 @@ def _bind_norm(self): self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self) return - def _bind_attention(self): - if "int8kv" in self.mode: - pass - elif "offline_calibration_fp8kv" in self.mode: - assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - # self._context_attention_kernel = partial( - # LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self - # ) - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self - # ) - elif "triton_flashdecoding" in self.mode: - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_flashdecoding" in self.mode: - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "export_fp8kv_calibration" in self.mode: - # self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, - # self) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) - elif not self.mode: - # if get_env_start_args().enable_flashinfer_decode: - # self._token_attention_kernel = partial( - # LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self - # ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode: {self.mode}") + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + layer_weight=layer_weight, + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor - return + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k = kv[:, 0 : self.tp_k_head_num_, :] + _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight @@ -131,60 +133,6 @@ def _tpsp_get_qkv( return q, cache_kv - def _context_attention_kernel( - self, - q: torch.Tensor, - kv: torch.Tensor, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerWeight, - out=None, - ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - o_tensor = infer_state.prefill_att_state.prefill_att( - q=_q, - k=_k, - v=_v, - layer_weight=layer_weight, - alloc_func=self.alloc_tensor, - ) - o_tensor = o_tensor.view(q.shape) - return o_tensor - - def _token_attention_kernel( - self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, out=None - ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] - o_tensor = infer_state.decode_att_state.decode_att( - q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor - ) - return o_tensor.view(q.shape) - - def _context_attention_flashinfer_kernel_fp8( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -273,36 +221,6 @@ def _tpsp_ffn( # gate_out, up_out = None, None # return ffn2_out - def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): - scales = mem_manager.scales - destindex_copy_kv_fp8( - buffer, - mem_index, - scales[self.layer_num_] if scales is not None else None, - mem_manager.kv_buffer[self.layer_num_].view(torch.float8_e4m3fn), - ) - return - - def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 033816b1e5..438e2f9575 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -9,14 +9,11 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id logger = init_logger(__name__) From 5203fd3fd93c572ddf6533206dee79b8b588cba4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:30:36 +0000 Subject: [PATCH 039/114] fix --- lightllm/models/llama/layer_infer/transformer_layer_infer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 17c4a19e2d..f16876d4c0 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -2,18 +2,13 @@ import triton import torch.distributed as dist from functools import partial - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) From 002c1f7746ba0c5d711ef2b4b3ee68505201e025 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:34:52 +0000 Subject: [PATCH 040/114] fix --- lightllm/server/api_cli.py | 8 ++++++++ lightllm/server/core/objs/start_args_type.py | 1 + 2 files changed, 9 insertions(+) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 6ed2553eab..7c2ec215b9 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -357,6 +357,14 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="""kv type used in llm""", ) + parser.add_argument( + "--llm_kv_quant_group_size", + type=int, + default=8, + help="""kv quant group size used in llm kv, when llm_kv_type is quanted type,such as int8kv, + this params will be effective. + """, + ) parser.add_argument( "--enable_flashinfer_prefill", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6065923ffb..73c7b51d85 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -121,6 +121,7 @@ class StartArgs: llm_prefill_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) llm_decode_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) llm_kv_type: str = field(default=None, metadata={"choices": [None, ""]}) + llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} From 6f3a71c5a6ac5f0a0d6c056b12d3dc4e4c2eb10f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:42:26 +0000 Subject: [PATCH 041/114] fix --- .../common/kv_cache_mem_manager/mem_utils.py | 25 ++++++------------- lightllm/server/api_cli.py | 4 +-- lightllm/server/core/objs/start_args_type.py | 2 +- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index ea8d9d2fb7..2960ac6d36 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -16,8 +16,6 @@ @lru_cache(maxsize=None) def select_mem_manager_class(): - mode = get_env_start_args().mode - # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() @@ -25,28 +23,21 @@ def select_mem_manager_class(): if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager - logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") + logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class # case normal - logger.info(f"mode setting params: {mode}") - if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode or "ppl_int8kv_flashdecoding_diverse" in mode: + logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}") + if get_env_start_args().llm_kv_type == "int8kv": memory_manager_class = PPLINT8KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "ppl_int4kv_flashdecoding" in mode: + elif get_env_start_args().llm_kv_type == "int4kv": memory_manager_class = PPLINT4KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "triton_fp8kv" in mode: - raise Exception("currently only for deepseek") - elif "offline_calibration_fp8kv" in mode: - memory_manager_class = CalibrationFP8KVMemoryManager - logger.info("Model kv cache using mode offline calibration fp8kv") - elif "export_fp8kv_calibration" in mode: + elif get_env_start_args().llm_kv_type == "fp8kv": memory_manager_class = ExportCalibrationMemoryManager - logger.info("Using mode export fp8kv calibration") - else: + elif get_env_start_args().llm_kv_type is None: memory_manager_class = MemoryManager - logger.info("Model kv cache using mode normal") + + logger.info(f"Model kv cache using mem_manager class: {memory_manager_class}") return memory_manager_class diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7c2ec215b9..8b16cb5196 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -353,9 +353,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--llm_kv_type", type=str, - choices=[None, ""], + choices=[None, "int8kv", "int4kv", "fp8kv"], default=None, - help="""kv type used in llm""", + help="""kv type used in llm, None for dtype that llm used in config.json""", ) parser.add_argument( "--llm_kv_quant_group_size", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 73c7b51d85..8ce31dd657 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -120,7 +120,7 @@ class StartArgs: enable_flashinfer_decode: bool = field(default=False) llm_prefill_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) llm_decode_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) - llm_kv_type: str = field(default=None, metadata={"choices": [None, ""]}) + llm_kv_type: str = field(default=None, metadata={"choices": [None, "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( From 415cecb747f4954db1ff44370eabe63548fb7749 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 08:49:14 +0000 Subject: [PATCH 042/114] fix --- .../template/transformer_layer_infer_template.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 436ca77d84..fdfc9a193e 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -3,8 +3,6 @@ import torch.distributed as dist from ..transformer_layer_infer import TransformerLayerInfer from ...infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.distributed import all_reduce from typing import Tuple from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -39,11 +37,11 @@ def _tpsp_get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tup def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight): mem_manager = infer_state.mem_manager - self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager) - return - - def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + mem_manager.copy_kv_to_mem_manager( + layer_index=self.layer_num_, + mem_index=infer_state.mem_index, + kv=cache_kv, + ) return def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: From 4bcf1485f0522fa08fe431ea2682deeef2091e2f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 09:12:26 +0000 Subject: [PATCH 043/114] fix --- .../common/kv_cache_mem_manager/deepseek2_mem_manager.py | 4 +++- .../export_calibration_mem_manager.py | 3 ++- lightllm/common/kv_cache_mem_manager/mem_manager.py | 3 ++- .../kv_cache_mem_manager/ppl_int4kv_mem_manager.py | 3 ++- .../kv_cache_mem_manager/ppl_int8kv_mem_manager.py | 3 ++- .../gpt_oss/layer_infer/transformer_layer_infer.py | 9 +++------ lightllm/models/mistral/model.py | 4 +--- lightllm/models/qwen2_vl/infer_struct.py | 3 --- 8 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3a8310a71a..a5e8f4dd8e 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -9,7 +9,7 @@ from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io -from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + logger = init_logger(__name__) @@ -22,6 +22,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: """ 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 """ + from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + rope_dim = 64 kv_lora_rank = kv.shape[2] - rope_dim diff --git a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py index a51691d226..af338f0413 100755 --- a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py @@ -1,6 +1,5 @@ import torch from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager): @@ -11,6 +10,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: """ 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 + scales = self.scales destindex_copy_kv_fp8( kv, diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 554327bb33..784a22a9ad 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -5,7 +5,6 @@ import torch.multiprocessing as mp from typing import List, Union from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt @@ -71,6 +70,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: """ 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv + destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) return diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index 95821ba786..3549e31bbb 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -1,7 +1,6 @@ import torch from .mem_manager import MemoryManager -from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv class PPLINT4KVMemoryManager(MemoryManager): @@ -14,6 +13,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: """ 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 """ + from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv + destindex_copy_int4kv( kv, mem_index, diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index f1d7a23543..0577c982ae 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -1,7 +1,6 @@ import torch from .mem_manager import MemoryManager -from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv class PPLINT8KVMemoryManager(MemoryManager): @@ -14,6 +13,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: """ 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 """ + from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv + destindex_copy_quantize_kv( kv, mem_index, diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 93cd7413ba..dbcf379468 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -6,7 +6,6 @@ from typing import Optional from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.log_utils import init_logger @@ -51,9 +50,7 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): hidden_states = hidden_states * torch.rsqrt(variance + eps) return (weight * hidden_states).to(input_dtype) # main diff with Llama - def _ffn( - self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight - ) -> torch.Tensor: + def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -69,7 +66,7 @@ def _ffn( return hidden_states.view(num_tokens, hidden_dim) def _context_sliding_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + self, q, kv, infer_state, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) @@ -107,7 +104,7 @@ def _context_sliding_attention_flashattention( return o def _token_sliding_attention_flashattention( - self, q, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + self, q, infer_state, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index d32f51ae78..3dc19ea799 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -8,7 +8,6 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @@ -44,8 +43,7 @@ def _init_custom(self): return def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo + pass def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 838590325c..c7b5ba041c 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -4,13 +4,10 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.get_mrope_position_ids import get_mrope_position_triton -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.utils.envs_utils import get_env_start_args class Qwen2VLInferStateInfo(LlamaInferStateInfo): - init_flash_attention_state_func = FlashAttentionStateInfo._init_flash_attention_state - def __init__(self): super().__init__() self.position_cos = None From dd8ee1d090a64e326f8e306bdfad4800c65f9fb1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 10:08:11 +0000 Subject: [PATCH 044/114] fix all --- .pre-commit-config.yaml | 2 +- .../common/basemodel/attention/__init__.py | 13 ++++ .../common/basemodel/attention/base_att.py | 4 ++ .../basemodel/attention/create_utils.py | 61 +++++++++++++++++++ .../common/basemodel/attention/fa3_backend.py | 10 +-- .../basemodel/attention/flashinfer_backend.py | 9 +-- .../basemodel/attention/fp8_fa3_backend.py | 8 +-- .../attention/fp8_flashinfer_backend.py | 7 +-- .../attention/int4kv_triton_backend.py | 7 ++- .../attention/int8kv_triton_backend.py | 6 +- lightllm/common/basemodel/basemodel.py | 12 ++-- .../export_calibration_mem_manager.py | 6 ++ .../kv_cache_mem_manager/mem_manager.py | 7 ++- .../ppl_int4kv_mem_manager.py | 9 ++- .../ppl_int8kv_mem_manager.py | 9 ++- .../layer_infer/transformer_layer_infer.py | 8 +-- test/test_api/test_generate_api.py | 2 +- 17 files changed, 134 insertions(+), 46 deletions(-) create mode 100644 lightllm/common/basemodel/attention/create_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 573ff399c5..e7e043a1f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,4 +10,4 @@ repos: rev: 6.1.0 hooks: - id: flake8 - args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231'] + args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231, F541'] diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 379f3079ac..b8384e85ab 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -1,2 +1,15 @@ from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from .triton_backend import TritonAttBackend, TritonPrefillAttState, TritonDecodeAttState +from .int4kv_triton_backend import Int4kvTritonAttBackend +from .int8kv_triton_backend import Int8kvTritonAttBackend +from .fa3_backend import Fa3AttBackend +from .fp8_fa3_backend import Fp8Fa3AttBackend +from .flashinfer_backend import FlashInferAttBackend +from .fp8_flashinfer_backend import Fp8FlashInferAttBackend + +from .create_utils import ( + get_prefill_att_backend_class, + get_decode_att_backend_class, + get_mla_prefill_att_backend_class, + get_mla_decode_att_backend_class, +) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index cccc81992a..3abff06e45 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.common.basemodel.infer_struct import InferStateInfo @@ -27,6 +28,9 @@ def __new__(cls, *args, **kwargs): # 返回已有的实例 return cls._instances[cls] + def __init__(self, model: "TpPartBaseModel"): + self.model = model + def create_att_prefill_state(self) -> "BasePrefillAttState": raise NotImplementedError("not impl") diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py new file mode 100644 index 0000000000..b255ec688a --- /dev/null +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -0,0 +1,61 @@ +from lightllm.utils.envs_utils import get_env_start_args +from .base_att import BaseAttBackend +from .triton_backend import TritonAttBackend +from .int4kv_triton_backend import Int4kvTritonAttBackend +from .int8kv_triton_backend import Int8kvTritonAttBackend +from .fa3_backend import Fa3AttBackend +from .fp8_fa3_backend import Fp8Fa3AttBackend +from .flashinfer_backend import FlashInferAttBackend +from .fp8_flashinfer_backend import Fp8FlashInferAttBackend + +backend_dict = { + None: { + "triton": TritonAttBackend, + "fa3": Fa3AttBackend, + "flash_infer": FlashInferAttBackend, + }, + "int4kv": { + "triton": Int4kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flash_infer": Fp8FlashInferAttBackend, + }, + "int8kv": { + "triton": Int8kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flash_infer": Fp8FlashInferAttBackend, + }, +} + + +def get_prefill_att_backend_class() -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + if args.llm_prefill_att_backend is not None: + return backend_dict[llm_dtype][args.llm_prefill_att_backend] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_decode_att_backend_class() -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + if args.llm_decode_att_backend is not None: + return backend_dict[llm_dtype][args.llm_decode_att_backend] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_mla_prefill_att_backend_class() -> BaseAttBackend: + # args = get_env_start_args() + # llm_dtype = args.llm_kv_type + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_mla_decode_att_backend_class() -> BaseAttBackend: + # args = get_env_start_args() + # llm_dtype = args.llm_kv_type + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py index d833a97ec8..149de18986 100644 --- a/lightllm/common/basemodel/attention/fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -2,20 +2,16 @@ import torch from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor -if TYPE_CHECKING: - from lightllm.common.basemodel.basemodel import TpPartBaseModel - class Fa3AttBackend(BaseAttBackend): - def __init__(self, model: "TpPartBaseModel"): - super().__init__() - self.model = model + def __init__(self, model): + super().__init__(model=model) self.get_page_table_buffer() # init def get_page_table_buffer(self): diff --git a/lightllm/common/basemodel/attention/flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer_backend.py index 487634b2e8..549a9d8a91 100644 --- a/lightllm/common/basemodel/attention/flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/flashinfer_backend.py @@ -1,18 +1,13 @@ import dataclasses import torch from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl -from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id from ..triton_kernel.repack_kv_index import repack_kv_index -if TYPE_CHECKING: - from lightllm.common.basemodel.basemodel import TpPartBaseModel - class FlashInferAttBackend(BaseAttBackend): - def __init__(self, model: "TpPartBaseModel"): - super().__init__() - self.model = model + def __init__(self, model): + super().__init__(model=model) tp_world_size = get_dp_world_size() self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fp8_fa3_backend.py index f6e22e10ec..cf6415a97f 100644 --- a/lightllm/common/basemodel/attention/fp8_fa3_backend.py +++ b/lightllm/common/basemodel/attention/fp8_fa3_backend.py @@ -15,14 +15,10 @@ else: scaled_fp8_quant = None -if TYPE_CHECKING: - from lightllm.common.basemodel.basemodel import TpPartBaseModel - class Fp8Fa3AttBackend(BaseAttBackend): - def __init__(self, model: "TpPartBaseModel"): - super().__init__() - self.model = model + def __init__(self, model): + super().__init__(model=model) self.get_page_table_buffer() # init def get_page_table_buffer(self): diff --git a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py index abed5f8b03..288264fa6a 100644 --- a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py @@ -1,16 +1,11 @@ import dataclasses import torch from .base_att import AttControl -from typing import Optional, TYPE_CHECKING -from ..triton_kernel.repack_kv_index import repack_kv_index from .flashinfer_backend import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState -if TYPE_CHECKING: - from lightllm.common.basemodel.basemodel import TpPartBaseModel - class Fp8FlashInferAttBackend(FlashInferAttBackend): - def __init__(self, model: "TpPartBaseModel"): + def __init__(self, model): super().__init__(model=model) self.kv_data_type = torch.float8_e4m3fn diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/int4kv_triton_backend.py index f1c1491c52..712eb9d608 100644 --- a/lightllm/common/basemodel/attention/int4kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int4kv_triton_backend.py @@ -1,13 +1,14 @@ import dataclasses import torch +from lightllm.utils.envs_utils import get_env_start_args from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, Tuple -from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel class Int4kvTritonAttBackend(BaseAttBackend): - def __init__(self, quant_group_size: int): - self.quant_group_size: int = quant_group_size + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size def create_att_prefill_state(self, infer_state) -> "Int4kvTritonPrefillAttState": return Int4kvTritonPrefillAttState(backend=self, infer_state=infer_state) diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index 361b9b2b5c..f4b710fe7d 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -1,13 +1,15 @@ import dataclasses import torch +from lightllm.utils.envs_utils import get_env_start_args from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, Tuple from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel class Int8kvTritonAttBackend(BaseAttBackend): - def __init__(self, quant_group_size: int): - self.quant_group_size: int = quant_group_size + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size def create_att_prefill_state(self, infer_state) -> "Int8kvTritonPrefillAttState": return Int8kvTritonPrefillAttState(backend=self, infer_state=infer_state) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 9f39ff6567..9ee45ab271 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache -from .attention import TritonAttBackend +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class logger = init_logger(__name__) @@ -120,7 +120,12 @@ def __init__(self, kvargs): self._init_inferstate_cls() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + self._init_att_backend() + + logger.info(f"use prefill att backend: {self.prefill_att_backend.__class__.__name__}") + logger.info(f"use decode att backend: {self.decode_att_backend.__class__.__name__}") + self._autotune_warmup() self._init_padded_req() self._init_cudagraph() @@ -241,9 +246,8 @@ def _init_some_value(self): return def _init_att_backend(self): - self.prefill_att_backend = TritonAttBackend() - self.decode_att_backend = TritonAttBackend() - assert id(self.prefill_att_backend) == id(self.decode_att_backend) + self.prefill_att_backend = get_prefill_att_backend_class()(model=self) + self.decode_att_backend = get_decode_att_backend_class()(model=self) return def _init_cudagraph(self): diff --git a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py index af338f0413..ffdc9b2c94 100755 --- a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py @@ -1,4 +1,5 @@ import torch +from typing import Tuple, Any from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager @@ -20,3 +21,8 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: self.kv_buffer[layer_index].view(torch.float8_e4m3fn), ) return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 784a22a9ad..1203cbdec7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,7 +3,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import List, Union +from typing import List, Union, Tuple, Any from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -75,6 +75,11 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) return + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index 3549e31bbb..559980dc12 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -24,6 +24,13 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: ) return + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index 0577c982ae..951d72e2c8 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -24,6 +24,13 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: ) return + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index f16876d4c0..6bb00291d0 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -44,10 +44,8 @@ def _context_attention_kernel( infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] o_tensor = infer_state.prefill_att_state.prefill_att( q=_q, k=_k, @@ -64,10 +62,8 @@ def _token_attention_kernel( infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight, ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] o_tensor = infer_state.decode_att_state.decode_att( q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor ) diff --git a/test/test_api/test_generate_api.py b/test/test_api/test_generate_api.py index 05fbda44ea..4ea74b7f6d 100644 --- a/test/test_api/test_generate_api.py +++ b/test/test_api/test_generate_api.py @@ -19,7 +19,7 @@ def run(self): print("Error:", response.status_code, response.text) -url = "http://localhost:8000/generate" +url = "http://localhost:8089/generate" headers = {"Content-Type": "application/json"} for i in range(1): From b0a59bbeae7bafdee65de89a4ef06a24386d35e3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 10:23:03 +0000 Subject: [PATCH 045/114] fix --- lightllm/common/basemodel/attention/fa3_backend.py | 12 ++++++------ .../common/basemodel/attention/fp8_fa3_backend.py | 2 +- lightllm/common/basemodel/infer_struct.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py index 149de18986..f84fef92ad 100644 --- a/lightllm/common/basemodel/attention/fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -19,7 +19,7 @@ def get_page_table_buffer(self): 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. """ model = self.model - if self._shared_page_table_buffer is None: + if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( get_current_device_id() @@ -88,8 +88,8 @@ def _nomarl_prefill_att( sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, - k_cache=k, - v_cache=v, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), page_table=self.page_table, cache_seqlens=self.infer_state.b_seq_len, cu_seqlens_q=self.cu_seqlens_q, @@ -147,7 +147,7 @@ def init_state(self): self.infer_state.batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch ): - page_buffer = self.backend.get_page_table_buffer(model.graph_max_batch_size, model.graph_max_len_in_batch) + page_buffer = self.backend.get_page_table_buffer() self.page_table = page_buffer[self.infer_state.microbatch_index][ : att_batch_size * model.graph_max_len_in_batch ].reshape(att_batch_size, model.graph_max_len_in_batch) @@ -210,8 +210,8 @@ def _normal_decode_att( sm_scale = 1.0 / (Lq ** 0.5) o = flash_attn_with_kvcache( q=q, - k_cache=k, - v_cache=v, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), page_table=self.page_table, cache_seqlens=self.b_att_seq_len, cu_seqlens_q=self.cu_seqlens_q, diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fp8_fa3_backend.py index cf6415a97f..2af794db43 100644 --- a/lightllm/common/basemodel/attention/fp8_fa3_backend.py +++ b/lightllm/common/basemodel/attention/fp8_fa3_backend.py @@ -26,7 +26,7 @@ def get_page_table_buffer(self): 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. """ model = self.model - if self._shared_page_table_buffer is None: + if not hasattr(self, "_shared_page_table_buffer"): self._shared_page_table_buffer = [ torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( get_current_device_id() diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index adad0a6abc..6a74cdf501 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -95,11 +95,6 @@ def __init__(self): self.dp_input_split_sizes: List[List[int]] = None def init_some_extra_state(self, model): - if self.is_prefill: - self.prefill_att_state.init_state() - else: - self.decode_att_state.init_state() - if self.is_prefill: ( self.b_q_seq_len, @@ -125,6 +120,11 @@ def init_some_extra_state(self, model): self.max_kv_seq_len = self.max_len_in_batch self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] + if self.is_prefill: + self.prefill_att_state.init_state() + else: + self.decode_att_state.init_state() + def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in vars(new_infer_state).items(): if isinstance(attr_value, torch.Tensor): From 120f28d15d0d737493332769e2f6282fe6134657 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 10:31:26 +0000 Subject: [PATCH 046/114] fix all --- lightllm/common/basemodel/basemodel.py | 16 ++++++++++++++-- lightllm/common/basemodel/infer_struct.py | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 9ee45ab271..f1d41af88f 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -33,6 +33,7 @@ from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache from .attention import get_prefill_att_backend_class, get_decode_att_backend_class +from .attention import BaseAttBackend logger = init_logger(__name__) @@ -246,8 +247,8 @@ def _init_some_value(self): return def _init_att_backend(self): - self.prefill_att_backend = get_prefill_att_backend_class()(model=self) - self.decode_att_backend = get_decode_att_backend_class()(model=self) + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class()(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class()(model=self) return def _init_cudagraph(self): @@ -481,6 +482,7 @@ def _prefill( prefill_mem_indexes_ready_event.record() infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._context_forward(infer_state) if is_padded_model_input: model_output = self._create_unpad_prefill_model_output( @@ -512,6 +514,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True @@ -531,6 +534,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._token_forward(infer_state) return model_output @@ -639,6 +643,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state0.max_q_seq_len, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() infer_state1 = self._create_inferstate(model_input1, 1) init_req_to_token_indexes( @@ -651,6 +656,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state1.max_q_seq_len, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -705,6 +711,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(padded_model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -713,6 +721,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True @@ -741,6 +750,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -749,6 +760,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) return model_output0, model_output1 diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 6a74cdf501..8951c63f1f 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -120,6 +120,7 @@ def init_some_extra_state(self, model): self.max_kv_seq_len = self.max_len_in_batch self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] + def init_att_state(self): if self.is_prefill: self.prefill_att_state.init_state() else: From 2371f7e0e1995b9bed6c3ad09849118d5a53085d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 10:34:42 +0000 Subject: [PATCH 047/114] fix all --- lightllm/common/basemodel/attention/create_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index b255ec688a..1e7f0e9ea5 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -12,17 +12,17 @@ None: { "triton": TritonAttBackend, "fa3": Fa3AttBackend, - "flash_infer": FlashInferAttBackend, + "flashinfer": FlashInferAttBackend, }, "int4kv": { "triton": Int4kvTritonAttBackend, "fa3": Fp8Fa3AttBackend, - "flash_infer": Fp8FlashInferAttBackend, + "flashinfer": Fp8FlashInferAttBackend, }, "int8kv": { "triton": Int8kvTritonAttBackend, "fa3": Fp8Fa3AttBackend, - "flash_infer": Fp8FlashInferAttBackend, + "flashinfer": Fp8FlashInferAttBackend, }, } From a4a7614feb6dfee214523191066ab860b3aa766f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 10:47:00 +0000 Subject: [PATCH 048/114] fix flashinfer --- .../common/basemodel/attention/flashinfer_backend.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/attention/flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer_backend.py index 549a9d8a91..151eadf9a8 100644 --- a/lightllm/common/basemodel/attention/flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/flashinfer_backend.py @@ -109,13 +109,11 @@ def _nomarl_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing - o_tensor = alloc_func(q.shape, q.dtype, device="cuda") self.prefill_wrapper.run( - q.view(q.shape[0], -1, self.backend.head_dim), - k.unsqueeze(1)[:, :, :, :], - v.unsqueeze(1)[:, :, :, :], - out=o_tensor.view(q.shape[0], -1, self.backend.head_dim), + q, + (k.unsqueeze(1), v.unsqueeze(1)), + out=o_tensor, ) return o_tensor @@ -224,8 +222,7 @@ def _normal_decode_att( o_tensor = alloc_func(q.shape, q.dtype) self.decode_wrapper.run( q, - k.unsqueeze(1)[:, :, :, :], - v.unsqueeze(1)[:, :, :, :], + (k.unsqueeze(1), v.unsqueeze(1)), out=o_tensor, ) return o_tensor From a330e9b1ba8627570af1f53fdc3c3ee2836dc7af Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 02:00:35 +0000 Subject: [PATCH 049/114] fix --- .../token_attention_nopad_reduceV.py | 223 ------------------ .../token_attention_nopad_softmax.py | 78 ------ .../token_attention_softmax_and_reducev.py | 112 --------- 3 files changed, 413 deletions(-) delete mode 100644 lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py delete mode 100644 lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py delete mode 100644 lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py deleted file mode 100644 index 243a8d1f66..0000000000 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_token_att2( - Prob, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): - BLOCK = 128 - # BLOCK = 64 # for triton 2.0.0dev - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2[grid]( - prob, - v, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_token_att2_int8v( - Prob, - V, - V_scale, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_vsbs, - stride_vsh, - stride_vsd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - vs_offs = cur_kv_head * stride_vsh - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - vs_value = tl.load( - V_scale + vs_offs + v_loc[:, None] * stride_vsbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): - if max_len_in_batch < 512: - BLOCK = triton.next_power_of_2(max_len_in_batch) - else: - BLOCK = 512 - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2_int8v[grid]( - prob, - v, - v_scale, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - v_scale.stride(0), - v_scale.stride(1), - v_scale.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - out = torch.matmul(P, V) - - return out diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py deleted file mode 100644 index 5e6040ac55..0000000000 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, - Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) - return - -@torch.no_grad() -def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): - BLOCK_SIZE = triton.next_power_of_2(max_input_len) - batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, - Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - -def test1(): - - import torch - - B, N_CTX, H, D = 4, 1025, 12, 128 - - dtype = torch.float16 - - Logics = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - ProbOut = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - - b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") - b_seq_len = torch.zeros((B,), dtype=torch.int32, device="cuda") - - for i in range(B): - b_start_loc[i] = i * N_CTX - b_seq_len[i] = N_CTX - - token_softmax_fwd(Logics, b_start_loc, b_seq_len, ProbOut, N_CTX) - - torch_out = Logics.reshape(H * B, -1).softmax(-1).reshape(H, B * N_CTX) - o = ProbOut - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py deleted file mode 100644 index d963f8582a..0000000000 --- a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -import triton -import triton.language as tl -import torch.nn.functional as F - - -@triton.jit -def _fwd_kernel( - Logics, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_token_b, - stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + (start_n + offs_n) * stride_req_to_token_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=other_kv_index, - ).to(tl.int64) - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), - ) - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - _fwd_kernel[grid]( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_tokens.stride(0), - req_to_tokens.stride(1), - 0, - kv_group_num, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return From 724a6009398c44d6cdfac7ea3674a5bba91f638b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 11:29:29 +0000 Subject: [PATCH 050/114] remove chatglm2 --- lightllm/models/__init__.py | 1 - lightllm/models/chatglm2/__init__.py | 0 .../models/chatglm2/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 28 --- .../models/chatglm2/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 20 --- .../layer_weights/transformer_layer_weight.py | 72 -------- lightllm/models/chatglm2/model.py | 78 --------- .../models/chatglm2/triton_kernel/__init__.py | 0 .../chatglm2/triton_kernel/rotary_emb.py | 160 ------------------ 10 files changed, 359 deletions(-) delete mode 100644 lightllm/models/chatglm2/__init__.py delete mode 100644 lightllm/models/chatglm2/layer_infer/__init__.py delete mode 100755 lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/chatglm2/layer_weights/__init__.py delete mode 100644 lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py delete mode 100755 lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/chatglm2/model.py delete mode 100644 lightllm/models/chatglm2/triton_kernel/__init__.py delete mode 100755 lightllm/models/chatglm2/triton_kernel/rotary_emb.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003b..afc3fc660a 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -8,7 +8,6 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel -from lightllm.models.chatglm2.model import ChatGlm2TpPartModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel diff --git a/lightllm/models/chatglm2/__init__.py b/lightllm/models/chatglm2/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/chatglm2/layer_infer/__init__.py b/lightllm/models/chatglm2/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 07ffc4beab..0000000000 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight - - -class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): - """ """ - - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) - return - - def swiglu(self, x): - x = torch.chunk(x, 2, dim=-1) - return torch.nn.functional.silu(x[0]) * x[1] - - def _ffn( - self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - input = None - ffn1_out = self.swiglu(up_gate_out) - up_gate_out = None - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/chatglm2/layer_weights/__init__.py b/lightllm/models/chatglm2/layer_weights/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 0139eb8837..0000000000 --- a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,20 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class ChatGLM2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - - self.wte_weight_ = EmbeddingWeight( - weight_name="transformer.embedding.word_embeddings.weight", data_type=self.data_type_ - ) - self.lm_head_weight_ = LMHeadWeight( - weight_name="transformer.output_layer.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - weight_name="transformer.encoder.final_layernorm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py deleted file mode 100755 index d4dd1b7a29..0000000000 --- a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,72 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight - - -class ChatGLM2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__( - layer_num, - data_type, - network_config, - mode, - quant_cfg, - ) - return - - def _preprocess_weight(self, weights): - n_kv_embed = self.head_dim * self.n_kv_head - qkv_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" - if qkv_weight_name in weights: - qkv_weight_ = weights[qkv_weight_name] - weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :] - weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :] - weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :] - del weights[qkv_weight_name] - - qkv_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" - if qkv_bias_name in weights: - qkv_bias_ = weights[qkv_bias_name] - weights[self._q_bias_name] = qkv_bias_[: self.n_embed] - weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed] - weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed] - del weights[qkv_bias_name] - - gate_up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight" - if gate_up_weight_name in weights: - gate_up_weight_ = weights[gate_up_weight_name] - weights[self._gate_weight_name] = gate_up_weight_[: self.n_inter, :] - weights[self._up_weight_name] = gate_up_weight_[self.n_inter : 2 * self.n_inter, :] - del weights[gate_up_weight_name] - - def _parse_config(self): - self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] - self.n_inter = self.network_config_["ffn_hidden_size"] - self.n_kv_head = self.network_config_["multi_query_group_num"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) - - def load_hf_weights(self, weights): - self._preprocess_weight(weights) - super().load_hf_weights(weights) - return - - def _init_weight_names(self): - self._q_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.weight" - self._q_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.bias" - self._k_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.weight" - self._k_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.bias" - self._v_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.weight" - self._v_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.bias" - self._o_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" - self._o_bias_name = None - - self._gate_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.gate_proj.weight" - self._gate_bias_name = None - self._up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.up_proj.weight" - self._up_bias_name = None - self._down_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" - self._down_bias_name = None - - self._att_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.input_layernorm.weight" - self._att_norm_bias_name = None - self._ffn_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" - self._ffn_norm_bias_name = None diff --git a/lightllm/models/chatglm2/model.py b/lightllm/models/chatglm2/model.py deleted file mode 100644 index e6aa395275..0000000000 --- a/lightllm/models/chatglm2/model.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import json -import torch - -from lightllm.models.registry import ModelRegistry -from lightllm.models.chatglm2.layer_infer.transformer_layer_infer import ChatGLM2TransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight -from lightllm.models.chatglm2.layer_weights.pre_and_post_layer_weight import ChatGLM2PreAndPostLayerWeight -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.build_utils import repair_config -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -@ModelRegistry("chatglm") -class ChatGlm2TpPartModel(LlamaTpPartModel): - # Please use the fast tokenizer from: - # [THUDM/chatglm3-6b PR #12](https://huggingface.co/THUDM/chatglm3-6b/discussions/12). - - # weight class - pre_and_post_weight_class = ChatGLM2PreAndPostLayerWeight - transformer_weight_class = ChatGLM2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = ChatGLM2TransformerLayerInfer - - def __init__(self, kvargs): - super().__init__(kvargs) - - def _init_config(self): - super()._init_config() - # rename key - # repair_config() - repair_config(self.config, same_names=["num_hidden_layers", "n_layer", "num_layers"]) - repair_config(self.config, same_names=["vocab_size", "padded_vocab_size"]) - repair_config(self.config, same_names=["rms_norm_eps", "layernorm_epsilon"]) - repair_config(self.config, same_names=["seq_length", "max_sequence_length"]) - return - - def _reset_num_key_value_heads(self): - self.config["num_key_value_heads"] = self.config["multi_query_group_num"] - return - - def _verify_params(self): - assert self.load_way == "HF", "ChatGLM only support HF format for now" - assert self.tp_world_size_ in [1, 2], "ChatGLM can only run in tp=1 or tp=2 for now" - - def _init_to_get_rotary(self, base=10000): - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_seq_len = self.config.get("max_position_embeddings", 2048) * rope_scaling_factor - - base = float(base) * self.config.get("rope_ratio", 1.0) - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - except: - pass - n_elem = self.head_dim_ // 2 - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/chatglm2/triton_kernel/__init__.py b/lightllm/models/chatglm2/triton_kernel/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py b/lightllm/models/chatglm2/triton_kernel/rotary_emb.py deleted file mode 100755 index ad1d1c2cf0..0000000000 --- a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return From 412d6e0cfca29e33a57e2e4d271d635d15c3c0eb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 12:30:57 +0000 Subject: [PATCH 051/114] fix --- .../common/basemodel/attention/base_att.py | 19 +++++++-- .../common/basemodel/attention/fa3_backend.py | 36 ++++++++++++---- .../basemodel/attention/flashinfer_backend.py | 19 +++++---- .../basemodel/attention/fp8_fa3_backend.py | 41 +++++++++++++------ .../attention/fp8_flashinfer_backend.py | 31 ++++++++------ .../attention/int4kv_triton_backend.py | 21 +++++----- .../attention/int8kv_triton_backend.py | 26 ++++++------ .../basemodel/attention/triton_backend.py | 36 +++++++--------- 8 files changed, 139 insertions(+), 90 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 3abff06e45..22b5ff4870 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -1,7 +1,7 @@ import torch from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Tuple, Union if TYPE_CHECKING: from lightllm.common.basemodel.basemodel import TpPartBaseModel @@ -37,6 +37,16 @@ def create_att_prefill_state(self) -> "BasePrefillAttState": def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") + def _find_layer_index( + self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] + ) -> int: + kv_buffer = att_state.infer_state.mem_manager.kv_buffer + layer_count = len(kv_buffer) + find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)} + key = min(k.data_ptr(), v.data_ptr()) + assert key in find_dict + return find_dict[key] + @dataclass class AttControl: @@ -67,7 +77,6 @@ def prefill_att( q: torch.Tensor, k: torch.tensor, v: torch.tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: @@ -93,7 +102,6 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: @@ -107,3 +115,8 @@ class AttControl: """ use_alibi: bool = False + tp_alibi: torch.Tensor = None + use_sliding_window: bool = False + sliding_window: Tuple[int, int] = (-1, -1) + use_att_sink: bool = False + sink_weight: torch.Tensor = None diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3_backend.py index f84fef92ad..794c0dff8a 100644 --- a/lightllm/common/basemodel/attention/fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3_backend.py @@ -65,7 +65,6 @@ def prefill_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: @@ -74,15 +73,25 @@ def prefill_att( q=q, k=k, v=v, - layer_weight=layer_weight, + att_control=att_control, alloc_func=alloc_func, ) def _nomarl_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) @@ -97,11 +106,12 @@ def _nomarl_prefill_att( max_seqlen_q=self.infer_state.max_q_seq_len, softmax_scale=sm_scale, causal=True, - window_size=(-1, -1), + window_size=window_size, softcap=0.0, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, + sinks=sink_weight, ) return o @@ -184,7 +194,6 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): @@ -193,7 +202,7 @@ def decode_att( q=q, k=k, v=v, - layer_weight=layer_weight, + att_control=att_control, alloc_func=alloc_func, ) @@ -202,9 +211,19 @@ def _normal_decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, + att_control: AttControl, alloc_func=torch.empty, ): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) @@ -219,10 +238,11 @@ def _normal_decode_att( max_seqlen_q=self.decode_max_q_seq_len, softmax_scale=sm_scale, causal=True, - window_size=(-1, -1), + window_size=window_size, softcap=0.0, k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, + sinks=sink_weight, ) return o diff --git a/lightllm/common/basemodel/attention/flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer_backend.py index 151eadf9a8..024e1e714d 100644 --- a/lightllm/common/basemodel/attention/flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/flashinfer_backend.py @@ -92,21 +92,23 @@ def prefill_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._nomarl_prefill_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) def _nomarl_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing o_tensor = alloc_func(q.shape, q.dtype, device="cuda") @@ -198,16 +200,18 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._normal_decode_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) @@ -216,7 +220,6 @@ def _normal_decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): o_tensor = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fp8_fa3_backend.py index 2af794db43..a03cf26d17 100644 --- a/lightllm/common/basemodel/attention/fp8_fa3_backend.py +++ b/lightllm/common/basemodel/attention/fp8_fa3_backend.py @@ -9,6 +9,7 @@ from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops +from typing import Union if HAS_VLLM: scaled_fp8_quant = vllm_ops.scaled_fp8_quant @@ -43,6 +44,16 @@ def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": def create_att_decode_state(self, infer_state) -> "Fp8Fa3DecodeAttState": return Fp8Fa3DecodeAttState(backend=self, infer_state=infer_state) + def _find_layer_index( + self, k: torch.Tensor, v: torch.Tensor, att_state: Union["Fp8Fa3PrefillAttState", "Fp8Fa3DecodeAttState"] + ) -> int: + kv_buffer = att_state.infer_state.mem_manager.kv_buffer + layer_count = len(kv_buffer) + find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)} + key = min(k.data_ptr(), v.data_ptr()) + assert key in find_dict + return find_dict[key] + @dataclasses.dataclass class Fp8Fa3PrefillAttState(BasePrefillAttState): @@ -105,21 +116,23 @@ def prefill_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._fp8_prefill_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) def _fp8_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fp8Fa3AttBackend = self.backend # for typing @@ -133,6 +146,7 @@ def _fp8_prefill_att( k_head_dim = k.shape[2] cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) o = flash_attn_with_kvcache( q=q, k_cache=cache_k, @@ -146,8 +160,8 @@ def _fp8_prefill_att( window_size=(-1, -1), softcap=0.0, q_descale=q_scale, - k_descale=self.k_descale[layer_weight.layer_num_], - v_descale=self.v_descale[layer_weight.layer_num_], + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], return_softmax_lse=False, ) return o @@ -261,16 +275,18 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._fp8_decode_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) @@ -279,7 +295,6 @@ def _fp8_decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): k_head_num = k.shape[1] @@ -288,6 +303,8 @@ def _fp8_decode_att( cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) + q_head_num = q.shape[1] q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True) o = flash_attn_with_kvcache( @@ -303,8 +320,8 @@ def _fp8_decode_att( window_size=(-1, -1), softcap=0.0, q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), - k_descale=self.k_descale[layer_weight.layer_num_], - v_descale=self.v_descale[layer_weight.layer_num_], + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], return_softmax_lse=False, ) return o diff --git a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py index 288264fa6a..c334cf6193 100644 --- a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py @@ -29,29 +29,31 @@ def prefill_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._fp8_prefill_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) def _fp8_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: o_tensor = alloc_func(q.shape, q.dtype, device="cuda") k = k.unsqueeze(1).view(torch.float8_e4m3fn) v = v.unsqueeze(1).view(torch.float8_e4m3fn) - + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) offline_scales = self.offline_scales - k_descale = offline_scales[layer_weight.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[layer_weight.layer_num_][1] if offline_scales is not None else None + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None self.prefill_wrapper.run( q, (k, v), @@ -75,16 +77,18 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._fp8_decode_att( q=q, k=k, v=v, - layer_weight=layer_weight, alloc_func=alloc_func, ) @@ -93,7 +97,6 @@ def _fp8_decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): o_tensor = alloc_func(q.shape, q.dtype, device="cuda") @@ -101,8 +104,10 @@ def _fp8_decode_att( k = k.unsqueeze(1).view(torch.float8_e4m3fn) v = v.unsqueeze(1).view(torch.float8_e4m3fn) offline_scales = self.offline_scales - k_descale = offline_scales[layer_weight.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[layer_weight.layer_num_][1] if offline_scales is not None else None + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) + + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None self.decode_wrapper.run( q, (k, v), diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/int4kv_triton_backend.py index 712eb9d608..2b0a075480 100644 --- a/lightllm/common/basemodel/attention/int4kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int4kv_triton_backend.py @@ -37,15 +37,17 @@ def prefill_att( q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: Tuple[torch.Tensor, torch.Tensor], - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) self.backend: Int4kvTritonAttBackend = self.backend # for typing - if self.backend.quant_group_size == 8: - pass + k, k_scale = k v, v_scale = v o = self._groupsize_quant_prefill_att( @@ -54,7 +56,6 @@ def prefill_att( k_scale=k_scale, v=v, v_scale=v_scale, - layer_weight=layer_weight, alloc_func=alloc_func, ) return o @@ -66,7 +67,6 @@ def _groupsize_quant_prefill_att( k_scale: torch.Tensor, v: torch.Tensor, v_scale: torch.Tensor, - layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) @@ -128,11 +128,14 @@ def decode_att( q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: Tuple[torch.Tensor, torch.Tensor], - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) k, k_scale = k v, v_scale = v @@ -142,7 +145,6 @@ def decode_att( k_scale=k_scale, v=v, v_scale=v_scale, - layer_weight=layer_weight, alloc_func=alloc_func, ) @@ -153,7 +155,6 @@ def ppl_int4kv_decode_att( k_scale: torch.Tensor, v: torch.Tensor, v_scale: torch.Tensor, - layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: from ..triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/int8kv_triton_backend.py index f4b710fe7d..bf0ffdfdc9 100644 --- a/lightllm/common/basemodel/attention/int8kv_triton_backend.py +++ b/lightllm/common/basemodel/attention/int8kv_triton_backend.py @@ -38,15 +38,17 @@ def prefill_att( q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: Tuple[torch.Tensor, torch.Tensor], - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) self.backend: Int8kvTritonAttBackend = self.backend # for typing - if self.backend.quant_group_size == 8: - pass + k, k_scale = k v, v_scale = v o = self._groupsize_quant_prefill_att( @@ -55,7 +57,6 @@ def prefill_att( k_scale=k_scale, v=v, v_scale=v_scale, - layer_weight=layer_weight, alloc_func=alloc_func, ) return o @@ -67,7 +68,6 @@ def _groupsize_quant_prefill_att( k_scale: torch.Tensor, v: torch.Tensor, v_scale: torch.Tensor, - layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) @@ -129,17 +129,18 @@ def decode_att( q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: Tuple[torch.Tensor, torch.Tensor], - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_alibi is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) k, k_scale = k v, v_scale = v if enable_diverse_mode_gqa_decode_fast_kernel(): - return self.diverse_decode_att( - q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, layer_weight=layer_weight, alloc_func=alloc_func - ) + return self.diverse_decode_att(q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, alloc_func=alloc_func) else: return self.ppl_mha_int8kv_decode_att( q=q, @@ -147,7 +148,6 @@ def decode_att( k_scale=k_scale, v=v, v_scale=v_scale, - layer_weight=layer_weight, alloc_func=alloc_func, ) @@ -158,7 +158,6 @@ def diverse_decode_att( k_scale: torch.Tensor, v: torch.Tensor, v_scale: torch.Tensor, - layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: @@ -183,7 +182,6 @@ def ppl_mha_int8kv_decode_att( k_scale: torch.Tensor, v: torch.Tensor, v_scale: torch.Tensor, - layer_weight, alloc_func=torch.empty, ) -> torch.Tensor: from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton_backend.py index 7b6d3fd918..ec19f16997 100644 --- a/lightllm/common/basemodel/attention/triton_backend.py +++ b/lightllm/common/basemodel/attention/triton_backend.py @@ -25,21 +25,22 @@ def prefill_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: + assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: - return self._alibi_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) + assert att_control.tp_alibi is not None + return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) + return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) def _alibi_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, + att_control: AttControl, alloc_func=torch.empty, ): out = alloc_func(q.shape, q.dtype) @@ -52,7 +53,7 @@ def _alibi_prefill_att( v, out, self.infer_state.b_req_idx, - layer_weight.tp_alibi, + att_control.tp_alibi, self.infer_state.b_start_loc, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, @@ -61,9 +62,7 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_weight, alloc_func=torch.empty - ): + def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd out = alloc_func(q.shape, q.dtype) @@ -95,23 +94,20 @@ def decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, att_control: AttControl = AttControl(), alloc_func=torch.empty, ): + assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: - return self._alibi_decode_att(q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func) + assert att_control.tp_alibi is not None + return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: q_head_num = q.shape[1] k_head_num = k.shape[1] if q_head_num == k_head_num: - return self._normal_decode_flash_decoding_att( - q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func - ) + return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) elif q_head_num > k_head_num: - return self._normal_decode_gqa_flash_decoding_att( - q=q, k=k, v=v, layer_weight=layer_weight, alloc_func=alloc_func - ) + return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) else: raise NotImplementedError("error") @@ -120,7 +116,7 @@ def _alibi_decode_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, + att_control: AttControl, alloc_func=torch.empty, ): from ..triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd @@ -131,7 +127,7 @@ def _alibi_decode_att( k, v, out, - layer_weight.tp_alibi, + att_control.tp_alibi, self.infer_state.req_manager.req_to_token_indexs, self.infer_state.b_req_idx, self.infer_state.b_start_loc, @@ -147,7 +143,6 @@ def _normal_decode_flash_decoding_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): from ..triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( @@ -171,7 +166,6 @@ def _normal_decode_gqa_flash_decoding_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): from ..triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( @@ -196,7 +190,6 @@ def _normal_decode_gqa_flash_decoding_att_vsm( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 @@ -245,7 +238,6 @@ def _normal_decode_stage3_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer_weight, alloc_func=torch.empty, ): total_token_num = self.infer_state.total_token_num From ae871837ed25e12dd421bb176bc76a1f6c2fcd00 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 12:32:40 +0000 Subject: [PATCH 052/114] fix --- .../layer_infer/transformer_layer_infer.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index d156c09279..84d9b96cf8 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -20,27 +20,6 @@ def __init__(self, layer_num, network_config, mode): self.embed_dim_ = network_config["n_embed"] return - def _att_norm( - self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight - ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor - ) - - def _ffn_norm( - self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight - ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor - ) - - def _get_qkv( - self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight - ) -> Tuple[torch.Tensor, torch.Tensor]: - q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - return q, cache_kv - def _context_attention_kernel( self, q: torch.Tensor, @@ -58,7 +37,7 @@ def _context_attention_kernel( k=_k, v=_v, layer_weight=layer_weight, - att_control=AttControl(use_alibi=True), + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), alloc_func=self.alloc_tensor, ) o_tensor = o_tensor.view(q.shape) @@ -76,11 +55,32 @@ def _token_attention_kernel( k=_k, v=_v, layer_weight=layer_weight, - att_control=AttControl(use_alibi=True), + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), alloc_func=self.alloc_tensor, ) return o_tensor.view(q.shape) + def _att_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.att_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor + ) + + def _ffn_norm( + self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.ffn_norm_weight_.layernorm_forward( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor + ) + + def _get_qkv( + self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight + ) -> Tuple[torch.Tensor, torch.Tensor]: + q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + return q, cache_kv + def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_)) return o_tensor From 9f504c6baeb0c5509359751c00b8a69f8cee3348 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 12:40:27 +0000 Subject: [PATCH 053/114] fix --- .../basemodel/attention/fp8_fa3_backend.py | 112 +----------------- 1 file changed, 6 insertions(+), 106 deletions(-) diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fp8_fa3_backend.py index a03cf26d17..b4aa55653f 100644 --- a/lightllm/common/basemodel/attention/fp8_fa3_backend.py +++ b/lightllm/common/basemodel/attention/fp8_fa3_backend.py @@ -10,6 +10,7 @@ from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union +from .fa3_backend import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState if HAS_VLLM: scaled_fp8_quant = vllm_ops.scaled_fp8_quant @@ -17,26 +18,9 @@ scaled_fp8_quant = None -class Fp8Fa3AttBackend(BaseAttBackend): +class Fp8Fa3AttBackend(Fa3AttBackend): def __init__(self, model): super().__init__(model=model) - self.get_page_table_buffer() # init - - def get_page_table_buffer(self): - """ - 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. - """ - model = self.model - if not hasattr(self, "_shared_page_table_buffer"): - self._shared_page_table_buffer = [ - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), - torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( - get_current_device_id() - ), - ] - return self._shared_page_table_buffer def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": return Fp8Fa3PrefillAttState(backend=self, infer_state=infer_state) @@ -44,41 +28,16 @@ def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": def create_att_decode_state(self, infer_state) -> "Fp8Fa3DecodeAttState": return Fp8Fa3DecodeAttState(backend=self, infer_state=infer_state) - def _find_layer_index( - self, k: torch.Tensor, v: torch.Tensor, att_state: Union["Fp8Fa3PrefillAttState", "Fp8Fa3DecodeAttState"] - ) -> int: - kv_buffer = att_state.infer_state.mem_manager.kv_buffer - layer_count = len(kv_buffer) - find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)} - key = min(k.data_ptr(), v.data_ptr()) - assert key in find_dict - return find_dict[key] - @dataclasses.dataclass -class Fp8Fa3PrefillAttState(BasePrefillAttState): - cu_seqlens_q: torch.Tensor = None - cu_seqlens_k: torch.Tensor = None - page_table: torch.Tensor = None +class Fp8Fa3PrefillAttState(Fa3PrefillAttState): # 临时共享变量 mid_token_batch_ids: torch.Tensor = None k_descale: torch.Tensor = None v_descale: torch.Tensor = None def init_state(self): - self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.infer_state.batch_size, self.infer_state.max_kv_seq_len), - dtype=torch.int32, - device=self.infer_state.input_ids.device, - ) - self.page_table.copy_( - self.infer_state.req_manager.req_to_token_indexs[ - self.infer_state.b_req_idx, : self.infer_state.max_kv_seq_len - ] - ) - + super().init_state() device = self.infer_state.input_ids.device batch_size = self.infer_state.batch_size mem_manager = self.backend.model.mem_manager @@ -168,77 +127,18 @@ def _fp8_prefill_att( @dataclasses.dataclass -class Fp8Fa3DecodeAttState(BaseDecodeAttState): - cu_seqlens_q: torch.Tensor = None - cu_seqlens_k: torch.Tensor = None - page_table: torch.Tensor = None - b_att_seq_len: torch.Tensor = None - # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 - decode_max_q_seq_len: int = None - +class Fp8Fa3DecodeAttState(Fa3DecodeAttState): k_descale: torch.Tensor = None v_descale: torch.Tensor = None def init_state(self): + super().init_state() self.backend: Fp8Fa3AttBackend = self.backend args_mtp_step = get_env_start_args().mtp_step - if args_mtp_step > 0: - # 修正 mtp 在 fa3 下的输入。 - mtp_size = args_mtp_step + 1 - b_q_seq_len = torch.full( - (self.infer_state.b_seq_len.shape[0] // mtp_size,), - fill_value=mtp_size, - dtype=torch.int32, - device=self.infer_state.b_seq_len.device, - ) - b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( - b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] - ) - self.cu_seqlens_q = b1_cu_q_seq_len.int() - self.cu_seqlens_k = b1_cu_kv_seq_len.int() - else: - self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 - model = self.backend.model - # 可以使用 cuda graph的时候从 buffer中申请 - if ( - self.infer_state.batch_size <= model.graph_max_batch_size - and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch - ): - page_buffer = self.backend.get_page_table_buffer(model.graph_max_batch_size, model.graph_max_len_in_batch) - self.page_table = page_buffer[self.infer_state.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].reshape(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty( - (att_batch_size, self.infer_state.max_kv_seq_len), - dtype=torch.int32, - device=self.infer_state.input_ids.device, - ) - - if args_mtp_step > 0: - page_table_copy( - page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - self.decode_max_q_seq_len = args_mtp_step + 1 - else: - page_table_copy( - page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.infer_state.b_req_idx, - ) - self.b_att_seq_len = self.infer_state.b_seq_len - self.decode_max_q_seq_len = 1 - device = self.infer_state.input_ids.device batch_size = att_batch_size mem_manager = self.backend.model.mem_manager From 031ed6fda95811ad766271d5e7bec4e2fc5e1b7c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 12:45:58 +0000 Subject: [PATCH 054/114] fix --- .../bloom/layer_infer/transformer_layer_infer.py | 10 ++-------- .../llama/layer_infer/transformer_layer_infer.py | 5 +---- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 84d9b96cf8..0316c3652b 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -28,15 +28,12 @@ def _context_attention_kernel( layer_weight: BloomTransformerLayerWeight, out=None, ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] o_tensor = infer_state.prefill_att_state.prefill_att( q=_q, k=_k, v=_v, - layer_weight=layer_weight, att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), alloc_func=self.alloc_tensor, ) @@ -46,15 +43,12 @@ def _context_attention_kernel( def _token_attention_kernel( self, q: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None ) -> torch.Tensor: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - _k = kv[:, 0 : self.tp_k_head_num_, :] - _v = kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] o_tensor = infer_state.decode_att_state.decode_att( q=_q, k=_k, v=_v, - layer_weight=layer_weight, att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), alloc_func=self.alloc_tensor, ) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 6bb00291d0..048101c998 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -50,7 +50,6 @@ def _context_attention_kernel( q=_q, k=_k, v=_v, - layer_weight=layer_weight, alloc_func=self.alloc_tensor, ) o_tensor = o_tensor.view(q.shape) @@ -64,9 +63,7 @@ def _token_attention_kernel( ) -> torch.Tensor: _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - o_tensor = infer_state.decode_att_state.decode_att( - q=_q, k=_k, v=_v, layer_weight=layer_weight, alloc_func=self.alloc_tensor - ) + o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) return o_tensor.view(q.shape) def _att_norm( From 3b3ad51256ad240c0c876e70e0c77a79a4f2e895 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 13:27:48 +0000 Subject: [PATCH 055/114] fix --- .../layer_infer/transformer_layer_infer.py | 200 ++---------------- 1 file changed, 17 insertions(+), 183 deletions(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ff20bc6ee6..703d494604 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -90,56 +90,32 @@ def _bind_ffn(self): self._tpsp_ffn = self._tpsp_ffn_tp def _bind_attention(self): - if "triton_fp8kv" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) + + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + if get_env_start_args().enable_fa3: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self + ) + elif get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self ) else: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self + ) + if self.enable_cc_method: if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) - elif get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self ) - if self.enable_cc_method: - if "triton_fp8kv" in self.mode: - if get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self - ) - else: - if get_env_start_args().enable_fa3: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self - ) - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self - ) - else: - if "triton_fp8kv" in self.mode: + elif get_env_start_args().enable_flashinfer_prefill: self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self + Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self ) else: self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self + Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self ) def _get_qkv( @@ -446,31 +422,6 @@ def _context_attention_flashinfer_kernel_with_CC( infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) return o_tensor - def _context_attention_flashinfer_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - def _context_attention_kernel_with_CC( self, q: torch.Tensor, @@ -507,100 +458,6 @@ def _context_attention_kernel_with_CC( ) return o_tensor - def _context_attention_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - context_attention_fwd_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -670,29 +527,6 @@ def _token_gqa_decode_attention_flashdecoding( ) return out - def _token_gqa_decode_attention_flashdecoding_fp8( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - return gqa_token_decode_attention_flash_decoding_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, - ) - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): destindex_copy_kv( buffer[:, :, : self.kv_lora_rank], From 292d961cfb6c2479f0e2737bf7ab1bbc6bf5d913 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 13:57:09 +0000 Subject: [PATCH 056/114] add triton mla decode. --- .../common/basemodel/attention/base_att.py | 26 +- .../basemodel/attention/mla_triton_backend.py | 100 +++++++ .../triton_kernel/mla_att/__init__.py | 0 .../mla_att/decode_att/__init__.py | 1 + .../mla_att/decode_att/gqa_flash_decoding.py | 149 ++++++++++ .../decode_att/gqa_flash_decoding_config.py | 63 ++++ .../decode_att/gqa_flash_decoding_stage1.py | 274 ++++++++++++++++++ .../decode_att/gqa_flash_decoding_stage2.py | 91 ++++++ 8 files changed, 689 insertions(+), 15 deletions(-) create mode 100644 lightllm/common/basemodel/attention/mla_triton_backend.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 22b5ff4870..167dd16011 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -1,7 +1,7 @@ import torch from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING, Tuple, Union +from typing import Optional, TYPE_CHECKING, Tuple, Union, Dict if TYPE_CHECKING: from lightllm.common.basemodel.basemodel import TpPartBaseModel @@ -55,6 +55,16 @@ class AttControl: """ use_alibi: bool = False + tp_alibi: torch.Tensor = None + use_sliding_window: bool = False + sliding_window: Tuple[int, int] = (-1, -1) + use_att_sink: bool = False + sink_weight: torch.Tensor = None + # mla 专用传参项 + mla_prefill: bool = False + mla_prefill_dict: Dict = None + mla_decode: bool = False + mla_decode_dict: Dict = None @dataclass @@ -106,17 +116,3 @@ def decode_att( alloc_func=torch.empty, ) -> torch.Tensor: pass - - -@dataclass -class AttControl: - """ - prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 - """ - - use_alibi: bool = False - tp_alibi: torch.Tensor = None - use_sliding_window: bool = False - sliding_window: Tuple[int, int] = (-1, -1) - use_att_sink: bool = False - sink_weight: torch.Tensor = None diff --git a/lightllm/common/basemodel/attention/mla_triton_backend.py b/lightllm/common/basemodel/attention/mla_triton_backend.py new file mode 100644 index 0000000000..e622afb7b4 --- /dev/null +++ b/lightllm/common/basemodel/attention/mla_triton_backend.py @@ -0,0 +1,100 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Tuple + + +class MlaTritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "MlaTritonPrefillAttState": + return MlaTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaTritonDecodeAttState": + return MlaTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaTritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def copy_for_prefill_cuda_graph(self, new_state: "MlaTritonPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_sliding_window is False and att_control.use_att_sink is False + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + pass + + +@dataclasses.dataclass +class MlaTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "MlaTritonDecodeAttState"): + pass + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_sliding_window is False + and att_control.use_att_sink is False + and att_control.use_alibi is False + ) + assert v is None + q_nope, q_rope = q + return self._mla_decode_att( + q_nope=q_nope, + q_rope=q_rope, + kv=k, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q_nope: torch.Tensor, + q_rope: torch.Tensor, + kv: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + assert att_control.mla_decode + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + + from ..triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding + + qk_rope_head_dim = 64 + + out = gqa_token_decode_attention_flash_decoding( + q_nope=q_nope, + q_rope=q_rope, + kv_nope=kv[:, :, :qk_rope_head_dim], + kv_rope=kv[:, :, -qk_rope_head_dim:], + infer_state=self.infer_state, + softmax_scale=softmax_scale, + alloc_tensor_func=alloc_func, + ) + return out diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py new file mode 100644 index 0000000000..fb0609a401 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py @@ -0,0 +1 @@ +from .gqa_flash_decoding import gqa_token_decode_attention_flash_decoding diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py new file mode 100644 index 0000000000..9d5f6bb8c9 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -0,0 +1,149 @@ +import os +import torch +import torch.multiprocessing as mp +import triton +import triton.language as tl +from typing import List +from lightllm.utils.log_utils import init_logger +from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig +from lightllm.utils.device_utils import get_device_sm_count + +logger = init_logger(__name__) + + +def gqa_token_decode_attention_flash_decoding( + q_nope, q_rope, kv_nope, kv_rope, infer_state, softmax_scale, out=None, alloc_tensor_func=torch.empty, **run_config +): + batch_size = infer_state.batch_size + max_len_in_batch = infer_state.max_len_in_batch + + q_head_num, kv_lora_rank = q_nope.shape[1], q_nope.shape[2] + q_rope_dim = q_rope.shape[2] + assert q_rope_dim == 64 + + calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) + calcu_shape2 = (batch_size, q_head_num, q_rope_dim) + + if not run_config: + if torch.cuda.is_current_stream_capturing(): + avg_seq_len_in_batch = max_len_in_batch + else: + avg_seq_len_in_batch = infer_state.total_token_num // batch_size + + run_config = MlaDecodeAttentionKernelConfig.try_to_get_best_config( + batch_size=batch_size, + avg_seq_len_in_batch=avg_seq_len_in_batch, + q_head_num=q_head_num, + q_head_dim=kv_lora_rank, + q_rope_dim=q_rope_dim, + out_dtype=torch.bfloat16, + ) + + BLOCK_N = run_config["BLOCK_N"] + + from .gqa_flash_decoding_stage1 import flash_decode_stage1 + from .gqa_flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out + + fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda") + mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") + mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda") + + vsm_count = flash_decode_stage1( + fake_decode_att_block_seq, + q_nope.view(calcu_shape1), + q_rope.view(calcu_shape2), + kv_nope, + kv_rope, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + mid_o, + mid_o_logexpsum, + softmax_scale, + get_sm_count=True, + **run_config + ) + + if not hasattr(infer_state, "decode_att_block_seq"): + assert batch_size <= 2048 + decode_att_block_seq = torch.empty( + [ + 1, + ], + dtype=torch.int64, + device="cuda", + ) + mid_o_batch_start_index = torch.empty( + [ + batch_size, + ], + dtype=torch.int64, + device="cuda", + ) + _fwd_kernel_calcu_index_and_block_seq[(1,)]( + infer_state.b_seq_len, + decode_att_block_seq, + mid_o_batch_start_index, + vsm_count, + batch_size, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + infer_state.decode_att_block_seq = decode_att_block_seq + infer_state.mid_o_batch_start_index = mid_o_batch_start_index + + mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") + mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") + + flash_decode_stage1( + infer_state.decode_att_block_seq, + q_nope.view(calcu_shape1), + q_rope.view(calcu_shape2), + kv_nope, + kv_rope, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + mid_o, + mid_o_logexpsum, + softmax_scale, + get_sm_count=False, + **run_config + ) + + flash_decode_stage2( + infer_state.decode_att_block_seq, + infer_state.mid_o_batch_start_index, + mid_o, + mid_o_logexpsum, + infer_state.b_seq_len, + o_tensor.view(calcu_shape1), + **run_config + ) + return o_tensor + + +@triton.jit +def _fwd_kernel_calcu_index_and_block_seq( + b_seq_len_ptr, + mid_o_decode_att_block_seq_ptr, + mid_o_batch_start_index_ptr, + num_sm, + batch_size, + BLOCK_N: tl.constexpr, +): + b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + total_token_num = tl.sum(b_seq_len) + + block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1 + block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N + + block_seq_len = tl.cdiv(b_seq_len, block_seq) + cumsum_seq_len = tl.cumsum(block_seq_len) + batch_start_index = cumsum_seq_len - block_seq_len + tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size) + tl.store(mid_o_decode_att_block_seq_ptr, block_seq) + return diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py new file mode 100644 index 0000000000..be99ca9bfc --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py @@ -0,0 +1,63 @@ +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict + + +class MlaDecodeAttentionKernelConfig(KernelConfigs): + kernel_name: str = "mla_decode_attentnion" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + q_head_num: int, + q_head_dim: int, + q_rope_dim: int, + out_dtype: str, + ) -> dict: + key_params = { + "q_head_num": q_head_num, + "q_head_dim": q_head_dim, + "q_rope_dim": q_rope_dim, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + # two search dim, first is avg_seq_len_in_batch, second is batch_size + batch_size_config: dict = finded_config[ + min(finded_config.keys(), key=lambda x: abs(int(x) - avg_seq_len_in_batch)) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "BLOCK_Q_HEAD": 16, + "stage1_num_warps": 4, + "stage1_num_stages": 2, + "stage2_num_warps": 4, + "stage2_num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, q_head_num: int, q_head_dim: int, q_rope_dim: int, out_dtype: str, config_json: Dict[int, Dict[int, Dict]] + ): + + key_params = { + "q_head_num": q_head_num, + "q_head_dim": q_head_dim, + "q_rope_dim": q_rope_dim, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py new file mode 100644 index 0000000000..f5909fffde --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py @@ -0,0 +1,274 @@ +import torch +import triton +import triton.language as tl +from lightllm.utils.device_utils import calcu_kernel_best_vsm_count + + +@triton.jit +def _fwd_kernel_flash_decode_stage1_padding( + Q_nope, + Q_rope, + KV_nope, + KV_rope, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_q_bs, + stride_q_h, + stride_q_d, + stride_q_rope_bs, + stride_q_rope_h, + stride_q_rope_d, + stride_kv_bs, + stride_kv_h, + stride_kv_d, + stride_kv_rope_bs, + stride_kv_rope_h, + stride_kv_rope_d, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eh, + stride_mid_o_es, + block_size_ptr, + num_sm, + head_group_num, + head_num, + batch_size, + Q_HEAD_NUM: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_STAGES: tl.constexpr, + NEED_HEAD_MASK: tl.constexpr, +): + # cur_kv_head = 0 + sm_id = tl.program_id(0).to(tl.int64) + out_batch_start_index = tl.cast(0, tl.int64) + block_seq = tl.load(block_size_ptr, eviction_policy="evict_last") + + cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + + for cur_batch in range(batch_size): + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch, eviction_policy="evict_last") + cur_block_num = tl.cdiv(cur_batch_seq_len, block_seq) * head_group_num + cur_batch_req_idx = tl.load(B_req_idx + cur_batch, eviction_policy="evict_last") + req_to_tokens_ptr = Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + while sm_id < cur_block_num: + loop_head_group_index = sm_id % head_group_num + loop_seq_block_index = sm_id // head_group_num + + cur_q_head_range = loop_head_group_index * Q_HEAD_NUM + cur_q_head_offs + if NEED_HEAD_MASK: + head_mask = cur_q_head_range < head_num + + cur_batch_start_index = block_seq * loop_seq_block_index + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + block_seq) + + off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :] + off_rope_q = ( + cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :] + ) + if NEED_HEAD_MASK: + q = tl.load( + Q_nope + off_q, + mask=head_mask[:, None], + other=0.0, + ) + q_rope = tl.load( + Q_rope + off_rope_q, + mask=head_mask[:, None], + other=0.0, + ) + else: + q = tl.load(Q_nope + off_q) + q_rope = tl.load(Q_rope + off_rope_q) + + block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) + max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") + acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) + for start_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): + offs_n_new = start_n * BLOCK_N + offs_n + seq_n_mask = offs_n_new < cur_batch_end_index + kv_loc = tl.load( + req_to_tokens_ptr + offs_n_new, + mask=seq_n_mask, + other=0, + ).to(tl.int64) + off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] + kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) + att_value = tl.dot(q, kv) + off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] + rope_kv = tl.load(KV_rope + off_rope_kv, mask=seq_n_mask[None, :], other=0.0) + att_value += tl.dot(q_rope, rope_kv) + + att_value *= sm_scale + att_value = tl.where(seq_n_mask[None, :], att_value, float("-inf")) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(kv.dtype), tl.trans(kv)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + cur_q_head_range[:, None] * stride_mid_oh + + (out_batch_start_index + loop_seq_block_index) * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_q_head_range * stride_mid_o_eh + out_batch_start_index + loop_seq_block_index + if NEED_HEAD_MASK: + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + mask=head_mask[:, None], + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + mask=head_mask, + ) + else: + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + ) + sm_id += num_sm + + out_batch_start_index += cur_block_num // head_group_num + sm_id -= cur_block_num + return + + +@torch.no_grad() +def flash_decode_stage1( + in_block_seq: torch.Tensor, + q_nope, + q_rope, + kv_nope, + kv_rope, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + softmax_scale, + get_sm_count: bool = False, + **run_config, +): + if run_config: + Q_HEAD_NUM = run_config["BLOCK_Q_HEAD"] + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["stage1_num_warps"] + num_stages = run_config["stage1_num_stages"] + + # shape constraints + q_nope_dim = q_nope.shape[-1] + q_rope_dim = q_rope.shape[-1] + + assert q_nope_dim == kv_nope.shape[-1] + assert q_rope_dim == kv_rope.shape[-1] + assert q_nope_dim in {16, 32, 64, 128, 256, 512} + assert q_rope_dim in {16, 32, 64, 128, 256} + assert kv_nope.shape[1] == 1 + + batch_size, q_head_num = B_req_idx.shape[0], q_nope.shape[1] + head_group_num = triton.cdiv(q_head_num, Q_HEAD_NUM) + NEED_HEAD_MASK = (q_head_num % Q_HEAD_NUM) != 0 + + kernel = _fwd_kernel_flash_decode_stage1_padding.warmup( + q_nope, + q_rope, + kv_nope, + kv_rope, + softmax_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + *Req_to_tokens.stride(), + *q_nope.stride(), + *q_rope.stride(), + *kv_nope.stride(), + *kv_rope.stride(), + *mid_out.stride(), + *mid_out_logsumexp.stride(), + in_block_seq, + num_sm=1, + head_group_num=head_group_num, + head_num=q_head_num, + batch_size=batch_size, + Q_HEAD_NUM=Q_HEAD_NUM, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK_N, + NEED_HEAD_MASK=NEED_HEAD_MASK, + NUM_STAGES=num_stages, + num_warps=num_warps, + num_stages=1, + grid=(1,), + ) + + kernel._init_handles() + num_sm = calcu_kernel_best_vsm_count(kernel, num_warps=num_warps) + grid = (num_sm,) + if get_sm_count: + return num_sm + + assert num_sm * 4 + batch_size <= mid_out.shape[1] + + _fwd_kernel_flash_decode_stage1_padding[grid]( + q_nope, + q_rope, + kv_nope, + kv_rope, + softmax_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + *Req_to_tokens.stride(), + *q_nope.stride(), + *q_rope.stride(), + *kv_nope.stride(), + *kv_rope.stride(), + *mid_out.stride(), + *mid_out_logsumexp.stride(), + in_block_seq, + num_sm=num_sm, + head_group_num=head_group_num, + head_num=q_head_num, + batch_size=batch_size, + Q_HEAD_NUM=Q_HEAD_NUM, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK_N, + NEED_HEAD_MASK=NEED_HEAD_MASK, + NUM_STAGES=num_stages, + num_warps=num_warps, + num_stages=1, + ) + + return diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py new file mode 100644 index 0000000000..5b5e7b747c --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py @@ -0,0 +1,91 @@ +import os +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + block_seq_ptr, + batch_start_index, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + Out, # [batch, head, head_dim] + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + BLOCK_DMODEL: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + cur_head = tl.program_id(0) + cur_batch = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = tl.load(batch_start_index + cur_batch) + block_seq = tl.load(block_seq_ptr) + + block_n_size = tl.cdiv(cur_batch_seq_len, block_seq) + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_head * stride_mid_oh + cur_batch_start_index * stride_mid_os + offs_d + offs_logic = cur_head * stride_mid_o_eh + cur_batch_start_index + for block_seq_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) + return + + +@torch.no_grad() +def flash_decode_stage2( + out_block_seq: torch.Tensor, + batch_start_index: torch.Tensor, + mid_out, + mid_out_logexpsum, + B_Seqlen, + Out, + **run_config +): + if run_config: + num_warps = run_config["stage2_num_warps"] + num_stages = run_config["stage2_num_stages"] + + Lk = mid_out.shape[-1] + assert Lk in {16, 32, 64, 128, 256, 512} + batch, head_num = batch_start_index.shape[0], mid_out.shape[0] + grid = (head_num, batch) + + _fwd_kernel_flash_decode_stage2[grid]( + out_block_seq, + batch_start_index, + B_Seqlen, + mid_out, + mid_out_logexpsum, + Out, + *mid_out.stride(), + *mid_out_logexpsum.stride(), + *Out.stride(), + BLOCK_DMODEL=Lk, + NUM_STAGES=num_stages, + num_warps=num_warps, + num_stages=1, + ) + return From e383e29862e8d7d9b84e31d950d78d52bfd4ea4b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 13:58:15 +0000 Subject: [PATCH 057/114] remove --- .../triton_kernel/gqa_flash_decoding.py | 156 ---------- .../gqa_flash_decoding_config.py | 63 ---- .../gqa_flash_decoding_stage1.py | 274 ------------------ .../gqa_flash_decoding_stage2.py | 91 ------ 4 files changed, 584 deletions(-) delete mode 100644 lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py delete mode 100644 lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py delete mode 100644 lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py delete mode 100644 lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py deleted file mode 100644 index 256dfce5af..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py +++ /dev/null @@ -1,156 +0,0 @@ -import os -import torch -import torch.multiprocessing as mp -import triton -import triton.language as tl -from typing import List -from lightllm.utils.log_utils import init_logger -from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig -from lightllm.utils.device_utils import get_device_sm_count - -logger = init_logger(__name__) - - -def gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv_nope, - kv_rope, - infer_state, - q_head_num, - kv_lora_rank, - q_rope_dim, - qk_nope_head_dim, - softmax_scale, - out=None, - alloc_tensor_func=torch.empty, - **run_config -): - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) - calcu_shape2 = (batch_size, q_head_num, q_rope_dim) - - if not run_config: - if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch - else: - avg_seq_len_in_batch = infer_state.total_token_num // batch_size - - run_config = MlaDecodeAttentionKernelConfig.try_to_get_best_config( - batch_size=batch_size, - avg_seq_len_in_batch=avg_seq_len_in_batch, - q_head_num=q_head_num, - q_head_dim=kv_lora_rank, - q_rope_dim=q_rope_dim, - out_dtype=torch.bfloat16, - ) - - BLOCK_N = run_config["BLOCK_N"] - - from .gqa_flash_decoding_stage1 import flash_decode_stage1 - from .gqa_flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out - - fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda") - mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda") - - vsm_count = flash_decode_stage1( - fake_decode_att_block_seq, - q_nope.view(calcu_shape1), - q_rope.view(calcu_shape2), - kv_nope, - kv_rope, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - mid_o, - mid_o_logexpsum, - softmax_scale, - get_sm_count=True, - **run_config - ) - - if not hasattr(infer_state, "decode_att_block_seq"): - assert batch_size <= 2048 - decode_att_block_seq = torch.empty( - [ - 1, - ], - dtype=torch.int64, - device="cuda", - ) - mid_o_batch_start_index = torch.empty( - [ - batch_size, - ], - dtype=torch.int64, - device="cuda", - ) - _fwd_kernel_calcu_index_and_block_seq[(1,)]( - infer_state.b_seq_len, - decode_att_block_seq, - mid_o_batch_start_index, - vsm_count, - batch_size, - BLOCK_N=BLOCK_N, - num_warps=4, - ) - - infer_state.decode_att_block_seq = decode_att_block_seq - infer_state.mid_o_batch_start_index = mid_o_batch_start_index - - mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") - - flash_decode_stage1( - infer_state.decode_att_block_seq, - q_nope.view(calcu_shape1), - q_rope.view(calcu_shape2), - kv_nope, - kv_rope, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - mid_o, - mid_o_logexpsum, - softmax_scale, - get_sm_count=False, - **run_config - ) - - flash_decode_stage2( - infer_state.decode_att_block_seq, - infer_state.mid_o_batch_start_index, - mid_o, - mid_o_logexpsum, - infer_state.b_seq_len, - o_tensor.view(calcu_shape1), - **run_config - ) - return o_tensor - - -@triton.jit -def _fwd_kernel_calcu_index_and_block_seq( - b_seq_len_ptr, - mid_o_decode_att_block_seq_ptr, - mid_o_batch_start_index_ptr, - num_sm, - batch_size, - BLOCK_N: tl.constexpr, -): - b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) - total_token_num = tl.sum(b_seq_len) - - block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1 - block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N - - block_seq_len = tl.cdiv(b_seq_len, block_seq) - cumsum_seq_len = tl.cumsum(block_seq_len) - batch_start_index = cumsum_seq_len - block_seq_len - tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size) - tl.store(mid_o_decode_att_block_seq_ptr, block_seq) - return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py deleted file mode 100644 index be99ca9bfc..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py +++ /dev/null @@ -1,63 +0,0 @@ -from lightllm.common.kernel_config import KernelConfigs -from frozendict import frozendict -from functools import lru_cache -from typing import Dict - - -class MlaDecodeAttentionKernelConfig(KernelConfigs): - kernel_name: str = "mla_decode_attentnion" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - batch_size: int, - avg_seq_len_in_batch: int, - q_head_num: int, - q_head_dim: int, - q_rope_dim: int, - out_dtype: str, - ) -> dict: - key_params = { - "q_head_num": q_head_num, - "q_head_dim": q_head_dim, - "q_rope_dim": q_rope_dim, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - # two search dim, first is avg_seq_len_in_batch, second is batch_size - batch_size_config: dict = finded_config[ - min(finded_config.keys(), key=lambda x: abs(int(x) - avg_seq_len_in_batch)) - ] - config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] - - return config - else: - config = { - "BLOCK_N": 16, - "BLOCK_Q_HEAD": 16, - "stage1_num_warps": 4, - "stage1_num_stages": 2, - "stage2_num_warps": 4, - "stage2_num_stages": 2, - } - return config - - @classmethod - def save_config( - cls, q_head_num: int, q_head_dim: int, q_rope_dim: int, out_dtype: str, config_json: Dict[int, Dict[int, Dict]] - ): - - key_params = { - "q_head_num": q_head_num, - "q_head_dim": q_head_dim, - "q_rope_dim": q_rope_dim, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py deleted file mode 100644 index f5909fffde..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +++ /dev/null @@ -1,274 +0,0 @@ -import torch -import triton -import triton.language as tl -from lightllm.utils.device_utils import calcu_kernel_best_vsm_count - - -@triton.jit -def _fwd_kernel_flash_decode_stage1_padding( - Q_nope, - Q_rope, - KV_nope, - KV_rope, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - Mid_O, # [head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [head, seq_block_num] - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_q_bs, - stride_q_h, - stride_q_d, - stride_q_rope_bs, - stride_q_rope_h, - stride_q_rope_d, - stride_kv_bs, - stride_kv_h, - stride_kv_d, - stride_kv_rope_bs, - stride_kv_rope_h, - stride_kv_rope_d, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eh, - stride_mid_o_es, - block_size_ptr, - num_sm, - head_group_num, - head_num, - batch_size, - Q_HEAD_NUM: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - NUM_STAGES: tl.constexpr, - NEED_HEAD_MASK: tl.constexpr, -): - # cur_kv_head = 0 - sm_id = tl.program_id(0).to(tl.int64) - out_batch_start_index = tl.cast(0, tl.int64) - block_seq = tl.load(block_size_ptr, eviction_policy="evict_last") - - cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - - for cur_batch in range(batch_size): - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch, eviction_policy="evict_last") - cur_block_num = tl.cdiv(cur_batch_seq_len, block_seq) * head_group_num - cur_batch_req_idx = tl.load(B_req_idx + cur_batch, eviction_policy="evict_last") - req_to_tokens_ptr = Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx - - while sm_id < cur_block_num: - loop_head_group_index = sm_id % head_group_num - loop_seq_block_index = sm_id // head_group_num - - cur_q_head_range = loop_head_group_index * Q_HEAD_NUM + cur_q_head_offs - if NEED_HEAD_MASK: - head_mask = cur_q_head_range < head_num - - cur_batch_start_index = block_seq * loop_seq_block_index - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + block_seq) - - off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :] - off_rope_q = ( - cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :] - ) - if NEED_HEAD_MASK: - q = tl.load( - Q_nope + off_q, - mask=head_mask[:, None], - other=0.0, - ) - q_rope = tl.load( - Q_rope + off_rope_q, - mask=head_mask[:, None], - other=0.0, - ) - else: - q = tl.load(Q_nope + off_q) - q_rope = tl.load(Q_rope + off_rope_q) - - block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") - acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) - for start_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): - offs_n_new = start_n * BLOCK_N + offs_n - seq_n_mask = offs_n_new < cur_batch_end_index - kv_loc = tl.load( - req_to_tokens_ptr + offs_n_new, - mask=seq_n_mask, - other=0, - ).to(tl.int64) - off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] - kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) - att_value = tl.dot(q, kv) - off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] - rope_kv = tl.load(KV_rope + off_rope_kv, mask=seq_n_mask[None, :], other=0.0) - att_value += tl.dot(q_rope, rope_kv) - - att_value *= sm_scale - att_value = tl.where(seq_n_mask[None, :], att_value, float("-inf")) - - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(kv.dtype), tl.trans(kv)) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic - - off_mid_o = ( - cur_q_head_range[:, None] * stride_mid_oh - + (out_batch_start_index + loop_seq_block_index) * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_q_head_range * stride_mid_o_eh + out_batch_start_index + loop_seq_block_index - if NEED_HEAD_MASK: - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - mask=head_mask[:, None], - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - mask=head_mask, - ) - else: - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - ) - sm_id += num_sm - - out_batch_start_index += cur_block_num // head_group_num - sm_id -= cur_block_num - return - - -@torch.no_grad() -def flash_decode_stage1( - in_block_seq: torch.Tensor, - q_nope, - q_rope, - kv_nope, - kv_rope, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - softmax_scale, - get_sm_count: bool = False, - **run_config, -): - if run_config: - Q_HEAD_NUM = run_config["BLOCK_Q_HEAD"] - BLOCK_N = run_config["BLOCK_N"] - num_warps = run_config["stage1_num_warps"] - num_stages = run_config["stage1_num_stages"] - - # shape constraints - q_nope_dim = q_nope.shape[-1] - q_rope_dim = q_rope.shape[-1] - - assert q_nope_dim == kv_nope.shape[-1] - assert q_rope_dim == kv_rope.shape[-1] - assert q_nope_dim in {16, 32, 64, 128, 256, 512} - assert q_rope_dim in {16, 32, 64, 128, 256} - assert kv_nope.shape[1] == 1 - - batch_size, q_head_num = B_req_idx.shape[0], q_nope.shape[1] - head_group_num = triton.cdiv(q_head_num, Q_HEAD_NUM) - NEED_HEAD_MASK = (q_head_num % Q_HEAD_NUM) != 0 - - kernel = _fwd_kernel_flash_decode_stage1_padding.warmup( - q_nope, - q_rope, - kv_nope, - kv_rope, - softmax_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - *Req_to_tokens.stride(), - *q_nope.stride(), - *q_rope.stride(), - *kv_nope.stride(), - *kv_rope.stride(), - *mid_out.stride(), - *mid_out_logsumexp.stride(), - in_block_seq, - num_sm=1, - head_group_num=head_group_num, - head_num=q_head_num, - batch_size=batch_size, - Q_HEAD_NUM=Q_HEAD_NUM, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK_N, - NEED_HEAD_MASK=NEED_HEAD_MASK, - NUM_STAGES=num_stages, - num_warps=num_warps, - num_stages=1, - grid=(1,), - ) - - kernel._init_handles() - num_sm = calcu_kernel_best_vsm_count(kernel, num_warps=num_warps) - grid = (num_sm,) - if get_sm_count: - return num_sm - - assert num_sm * 4 + batch_size <= mid_out.shape[1] - - _fwd_kernel_flash_decode_stage1_padding[grid]( - q_nope, - q_rope, - kv_nope, - kv_rope, - softmax_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - *Req_to_tokens.stride(), - *q_nope.stride(), - *q_rope.stride(), - *kv_nope.stride(), - *kv_rope.stride(), - *mid_out.stride(), - *mid_out_logsumexp.stride(), - in_block_seq, - num_sm=num_sm, - head_group_num=head_group_num, - head_num=q_head_num, - batch_size=batch_size, - Q_HEAD_NUM=Q_HEAD_NUM, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK_N, - NEED_HEAD_MASK=NEED_HEAD_MASK, - NUM_STAGES=num_stages, - num_warps=num_warps, - num_stages=1, - ) - - return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py deleted file mode 100644 index 5b5e7b747c..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage2( - block_seq_ptr, - batch_start_index, - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] - Out, # [batch, head, head_dim] - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eh, - stride_mid_o_es, - stride_obs, - stride_oh, - stride_od, - BLOCK_DMODEL: tl.constexpr, - NUM_STAGES: tl.constexpr, -): - cur_head = tl.program_id(0) - cur_batch = tl.program_id(1) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = tl.load(batch_start_index + cur_batch) - block_seq = tl.load(block_seq_ptr) - - block_n_size = tl.cdiv(cur_batch_seq_len, block_seq) - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - offs_v = cur_head * stride_mid_oh + cur_batch_start_index * stride_mid_os + offs_d - offs_logic = cur_head * stride_mid_o_eh + cur_batch_start_index - for block_seq_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) - new_max_logic = tl.maximum(tlogic, max_logic) - - old_scale = tl.exp(max_logic - new_max_logic) - acc *= old_scale - exp_logic = tl.exp(tlogic - new_max_logic) - acc += exp_logic * tv - sum_exp = sum_exp * old_scale + exp_logic - max_logic = new_max_logic - - tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) - return - - -@torch.no_grad() -def flash_decode_stage2( - out_block_seq: torch.Tensor, - batch_start_index: torch.Tensor, - mid_out, - mid_out_logexpsum, - B_Seqlen, - Out, - **run_config -): - if run_config: - num_warps = run_config["stage2_num_warps"] - num_stages = run_config["stage2_num_stages"] - - Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128, 256, 512} - batch, head_num = batch_start_index.shape[0], mid_out.shape[0] - grid = (head_num, batch) - - _fwd_kernel_flash_decode_stage2[grid]( - out_block_seq, - batch_start_index, - B_Seqlen, - mid_out, - mid_out_logexpsum, - Out, - *mid_out.stride(), - *mid_out_logexpsum.stride(), - *Out.stride(), - BLOCK_DMODEL=Lk, - NUM_STAGES=num_stages, - num_warps=num_warps, - num_stages=1, - ) - return From ec503f7fc64f9a5be5195dc7e263c10693eeabb4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 14:08:28 +0000 Subject: [PATCH 058/114] fix deepseek --- .../deepseek2_mem_manager.py | 6 ++++- .../layer_infer/transformer_layer_infer.py | 27 +++++++------------ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index a5e8f4dd8e..3629acf973 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -3,7 +3,7 @@ import torch.distributed as dist from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager -from typing import List, Union +from typing import List, Union, Any from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node @@ -36,6 +36,10 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: ) return + def get_att_input_params(self, layer_index: int) -> Any: + kv = self.kv_buffer[layer_index] + return kv + def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 703d494604..2d42b21ce1 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -8,14 +8,11 @@ from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv from lightllm.models.deepseek2.triton_kernel.repeat_rope import repeat_rope -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -511,20 +508,16 @@ def _token_gqa_decode_attention_flashdecoding( ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - out = gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, + kv = infer_state.mem_manager.get_att_input_params() + + out = infer_state.decode_att_state.decode_att( + q=(q_nope, q_rope), + k=kv, + v=None, + att_control=AttControl(mla_decode=True, mla_decode_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, ) + return out def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): From af8cd0ccee3552d69b741acd8ea04aca6e51c819 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 7 Jan 2026 14:09:59 +0000 Subject: [PATCH 059/114] fix --- .../models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py index fd437c3888..ed2f564b5a 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py @@ -5,7 +5,6 @@ import triton.language as tl from typing import List from lightllm.utils.log_utils import init_logger -from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig from lightllm.utils.device_utils import get_device_sm_count logger = init_logger(__name__) @@ -38,6 +37,8 @@ def gqa_token_decode_attention_flash_decoding_fp8( else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size + from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig + run_config = MlaDecodeAttentionKernelConfig.try_to_get_best_config( batch_size=batch_size, avg_seq_len_in_batch=avg_seq_len_in_batch, From 198cc9bf4fe523ce98b2e6cc9f8705eaec8d5d53 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 02:18:15 +0000 Subject: [PATCH 060/114] add triton mla prefill --- .../basemodel/attention/mla_triton_backend.py | 34 ++- .../mla_att/prefill_att/__init__.py | 1 + .../context_flashattention_nopad_with_v.py | 278 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 80 ++--- .../deepseek2/triton_kernel/sample_kv.py | 151 +++++----- 5 files changed, 407 insertions(+), 137 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py diff --git a/lightllm/common/basemodel/attention/mla_triton_backend.py b/lightllm/common/basemodel/attention/mla_triton_backend.py index e622afb7b4..e735d9067e 100644 --- a/lightllm/common/basemodel/attention/mla_triton_backend.py +++ b/lightllm/common/basemodel/attention/mla_triton_backend.py @@ -23,23 +23,49 @@ def copy_for_prefill_cuda_graph(self, new_state: "MlaTritonPrefillAttState"): def prefill_att( self, q: torch.Tensor, - k: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], v: torch.Tensor, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_sliding_window is False and att_control.use_att_sink is False + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) def _mla_prefill_att( self, q: torch.Tensor, - k: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty, ): - pass + from ..triton_kernel.mla_att.prefill_att import context_attention_fwd_with_v + + qk_rope_head_dim = 64 + q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + k_nope, k_rope = k + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o_tensor, + self.infer_state.b_start_loc, + self.infer_state.b1_cu_kv_seq_len, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + softmax_scale, + ) + return o_tensor @dataclasses.dataclass diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py new file mode 100644 index 0000000000..5725bed2e7 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py @@ -0,0 +1 @@ +from .context_flashattention_nopad_with_v import context_attention_fwd_with_v diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py new file mode 100644 index 0000000000..be06351823 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py @@ -0,0 +1,278 @@ +import torch + +import triton +import triton.language as tl +import math +import torch.nn.functional as F +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel_with_v( + Q_nope, + Q_rope, + K_nope, + K_rope, + V, + sm_scale, + B_Start_Loc, + B_Kv_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + stride_q_bs, + stride_q_h, + stride_q_rope_bs, + stride_q_rope_h, + stride_k_bs, + stride_k_h, + stride_k_rope_bs, + stride_k_rope_h, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + b_prompt_cache_len, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_k_head = cur_head + + cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] + off_q_rope = ( + (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs + + cur_head * stride_q_rope_h + + offs_rope_d[None, :] + ) + off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] + off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] + + q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K_nope + off_k + k_rope_ptrs = K_rope + off_k_rope + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, + mask=(start_n + offs_n[None, :]) < block_end_loc, + other=0.0, + ) + k_rope = tl.load( + k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, + mask=(start_n + offs_n[None, :]) < block_end_loc, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk += tl.dot(q_rope, k_rope) + qk *= sm_scale + qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) + + # -- compute m_ij, p, l_ij + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < block_end_loc, + other=0.0, + ) + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + # update m_i and l_i + m_i = m_ij + + acc = acc / l_i[:, None] + # initialize pointers to output + off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + +@torch.no_grad() +def context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o, + b_start_loc, + b_kv_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + softmax_scale, +): + + BLOCK = 128 if not is_tesla() else 64 + q_nope_dim = q_nope.shape[-1] + q_rope_dim = q_rope.shape[-1] + assert q_nope_dim == k_nope.shape[-1] + assert q_rope_dim == k_rope.shape[-1] + assert q_nope_dim in {16, 32, 64, 128, 256, 512} + assert q_rope_dim in {16, 32, 64, 128, 256} + assert q_nope_dim == v.shape[-1] + + if q_nope_dim >= 512: + BLOCK = 64 if not is_tesla() else 32 + else: + BLOCK = 128 if not is_tesla() else 64 + + if q_nope.dtype == torch.float32: + BLOCK = BLOCK // 4 + + sm_scale = softmax_scale * 1.4426950408889634 + batch, head = b_seq_len.shape[0], q_nope.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + num_warps = 4 if q_nope_dim <= 64 else 8 + + _fwd_kernel_with_v[grid]( + q_nope, + q_rope, + k_nope, + k_rope, + v, + sm_scale, + b_start_loc, + b_kv_start_loc, + b_seq_len, + o, + q_nope.stride(0), + q_nope.stride(1), + q_rope.stride(0), + q_rope.stride(1), + k_nope.stride(0), + k_nope.stride(1), + k_rope.stride(0), + k_rope.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + b_prompt_cache_len=b_prompt_cache_len, + BLOCK_M=BLOCK, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +if __name__ == "__main__": + import torch + import flashinfer + + Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64 + dtype = torch.bfloat16 + + k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda") + k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda") + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1) + v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + softmax_scale = 0.117 + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") + q_lens = b_seq_len - b_prompt_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + kv_start_loc = b_seq_len.cumsum(0) - b_seq_len + + q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") + q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda") + q = torch.cat([q_nope, q_rope], dim=-1) + + o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") + o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") + o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") + + fn1 = lambda: context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o, + q_start_loc, + kv_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + softmax_scale, + ) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_layout = "NHD" + batch_size = Z + q_indptr = q_starts + kv_indptr = kv_starts + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper.plan( + qo_indptr=q_indptr, + kv_indptr=kv_indptr, + num_qo_heads=H, + num_kv_heads=H, + head_dim_qk=D_HEAD + ROPE_HEAD, + head_dim_vo=D_HEAD, + q_data_type=dtype, + causal=True, + sm_scale=softmax_scale, + ) + fn2 = lambda: wrapper.run(q, k, v, out=o1) + + ms1 = triton.testing.do_bench(fn1) + ms2 = triton.testing.do_bench(fn2) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + print(ms1) + print(ms2) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 2d42b21ce1..9347cf83fb 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -272,57 +272,37 @@ def _tpsp_get_o( def _decompress_kv( self, - kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, - is_fp8, - total_token_num, - b_seq_len, - max_seq_len, - b_kv_start_loc, - skip_sample=False, ): - if not skip_sample: - if is_fp8: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype) - else: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = None - k_scale = None - - compressed_kv = self.alloc_tensor([total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype) - k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype) - sample_kv( - kv, - compressed_kv, - k_rope, - infer_state.b_req_idx, - max_seq_len, - b_seq_len, - infer_state.req_manager.req_to_token_indexs, - b_kv_start_loc, - kv_scale, - k_scale, - ) - if k_scale is not None: - compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1) - k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1) - else: - compressed_kv, k_rope = torch.split( # (b*s, 1, kv_lora + qk_r) - kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1 - ) + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + total_token_num = infer_state.total_token_num + sampled_compressed_kv_nope = self.alloc_tensor( + [total_token_num, 1, layer_weight.kv_lora_rank], dtype=compressed_kv.dtype + ) + sampled_k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=compressed_kv.dtype) + sample_kv( + all_compressed_kv=compressed_kv, + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + sampled_k_rope=sampled_k_rope, + b_req_idx=infer_state.b_req_idx, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + b_seq_len=infer_state.b_seq_len, + b_kv_start_loc=infer_state.b1_kv_start_loc[:-1], + max_kv_seq_len=infer_state.max_kv_seq_len, + ) # CC - compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous() - kv_nope = self.alloc_tensor( - [compressed_kv.shape[0], self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], - dtype=compressed_kv.dtype, + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view( + total_token_num, layer_weight.kv_lora_rank + ).contiguous() + sampled_kv_nope = self.alloc_tensor( + [total_token_num, self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], + dtype=sampled_compressed_kv_nope.dtype, ) - layer_weight.cc_kv_b_proj_.mm(compressed_kv, out=kv_nope.reshape(compressed_kv.shape[0], -1)) - k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - return k_nope, k_rope, v + layer_weight.cc_kv_b_proj_.mm(sampled_compressed_kv_nope, out=sampled_kv_nope.view(total_token_num, -1)) + sampled_k_nope, sampled_v = torch.split(sampled_kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + return sampled_k_nope, sampled_k_rope, sampled_v # Adapted from: # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962 @@ -428,14 +408,8 @@ def _context_attention_kernel_with_CC( out=None, ) -> torch.Tensor: k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, + infer_state=infer_state, + layer_weight=layer_weight, ) q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index af0aaa2f66..ace118c9e5 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -8,111 +8,102 @@ @triton.jit def _sample_kv_kernel( - KV_input, - KV_scale, - KV_nope, - KV_rope, - K_scale, - B_start_loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_input_dim, - stride_scale_dim, - stride_nope_dim, - stride_rope_dim, + all_compressed_kv, + stride_all_s, + stride_all_d, + sampled_compressed_kv_nope, + stride_nope_s, + stride_nope_d, + sampled_k_rope, + stride_rope_s, + stride_rope_d, + b_kv_start_loc, + b_seq_len, + req_to_token_indexs, stride_req_to_tokens_b, - HAS_SCALE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, + b_req_idx, + BLOCK_SEQ: tl.constexpr, + BLOCK_NOPE_DIM: tl.constexpr, + BLOCK_ROPE_DIM: tl.constexpr, ): cur_batch = tl.program_id(0) start_m = tl.program_id(1) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_loc = tl.load(B_start_loc + cur_batch) + cur_batch_seq_len = tl.load(b_seq_len + cur_batch) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_batch_start_loc = tl.load(b_kv_start_loc + cur_batch) - offs_nope_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_nope_d = tl.arange(0, BLOCK_NOPE_DIM) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DIM) + offs_m = (start_m * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)) % cur_batch_seq_len - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + if (start_m + 1) * BLOCK_SEQ > cur_batch_seq_len: + return kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, - mask=offs_m < block_end_loc, + req_to_token_indexs + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, other=0, ).to(tl.int64) - off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] - off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] - kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) - kv_rope = tl.load(KV_input + off_kv_rope, mask=offs_m[:, None] < block_end_loc, other=0.0) - off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_dim + offs_nope_d[None, :] - off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_dim + offs_rope_d[None, :] - nope_ptrs = KV_nope + off_nope - rope_ptrs = KV_rope + off_rope - tl.store(nope_ptrs, kv_nope, mask=offs_m[:, None] < block_end_loc) - tl.store(rope_ptrs, kv_rope, mask=offs_m[:, None] < block_end_loc) - if HAS_SCALE: - kv_scale = tl.load(KV_scale + kv_loc * stride_scale_dim, mask=offs_m < block_end_loc) - off_k_scale = cur_batch_start_loc + offs_m - k_scale_ptrs = K_scale + off_k_scale - tl.store(k_scale_ptrs, kv_scale, mask=offs_m < block_end_loc) + off_kv_nope = kv_loc[:, None] * stride_all_s + offs_nope_d[None, :] + off_kv_rope = kv_loc[:, None] * stride_all_s + (offs_rope_d + BLOCK_NOPE_DIM)[None, :] + kv_nope = tl.load(all_compressed_kv + off_kv_nope) + kv_rope = tl.load(all_compressed_kv + off_kv_rope) + off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_s + offs_nope_d[None, :] + off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_s + offs_rope_d[None, :] + nope_ptrs = sampled_compressed_kv_nope + off_nope + rope_ptrs = sampled_k_rope + off_rope + tl.store(nope_ptrs, kv_nope) + tl.store(rope_ptrs, kv_rope) return @torch.no_grad() def sample_kv( - kv_input, - kv_nope, - kv_rope, - b_req_idx, - max_value_in_b_seq_len, - b_seq_len, - req_to_token_indexs, - b_kv_start_loc, - kv_scale=None, - k_scale=None, + all_compressed_kv: torch.Tensor, + sampled_compressed_kv_nope: torch.Tensor, + sampled_k_rope: torch.Tensor, + b_req_idx: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_kv_start_loc: torch.Tensor, + max_kv_seq_len: int, ): - BLOCK = 128 if not is_tesla() else 64 - - nope_dim = kv_nope.shape[-1] - rope_dim = kv_rope.shape[-1] - if nope_dim >= 512: - BLOCK = 64 if not is_tesla() else 32 - else: - BLOCK = 128 if not is_tesla() else 64 + nope_dim = sampled_compressed_kv_nope.shape[-1] + rope_dim = sampled_k_rope.shape[-1] batch = b_seq_len.shape[0] - max_input_len = max_value_in_b_seq_len + BLOCK = 64 if not is_tesla() else 32 + num_warps = 4 grid = ( batch, - triton.cdiv(max_input_len, BLOCK), + triton.cdiv(max_kv_seq_len, BLOCK), ) - num_warps = 4 if nope_dim <= 64 else 8 + + all_compressed_kv = all_compressed_kv.view(all_compressed_kv.shape[0], all_compressed_kv.shape[2]) + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view(sampled_compressed_kv_nope.shape[0], nope_dim) + sampled_k_rope = sampled_k_rope.view(sampled_k_rope.shape[0], rope_dim) + assert triton.next_power_of_2(nope_dim) == nope_dim + assert triton.next_power_of_2(rope_dim) == rope_dim _sample_kv_kernel[grid]( - kv_input, - kv_scale, - kv_nope, - kv_rope, - k_scale, - b_kv_start_loc, - b_seq_len, - req_to_token_indexs, - b_req_idx, - kv_input.stride(0), - kv_scale.stride(0) if kv_scale is not None else 0, - kv_nope.stride(0), - kv_rope.stride(0), - req_to_token_indexs.stride(0), - HAS_SCALE=kv_scale is not None, - BLOCK_M=BLOCK, - BLOCK_DMODEL=nope_dim, - BLOCK_ROPE_DMODEL=rope_dim, + all_compressed_kv=all_compressed_kv, + stride_all_s=all_compressed_kv.stride(0), + stride_all_d=all_compressed_kv.stride(1), + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + stride_nope_s=sampled_compressed_kv_nope.stride(0), + stride_nope_d=sampled_compressed_kv_nope.stride(1), + sampled_k_rope=sampled_k_rope, + stride_rope_s=sampled_k_rope.stride(0), + stride_rope_d=sampled_k_rope.stride(1), + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + b_req_idx=b_req_idx, + BLOCK_SEQ=BLOCK, + BLOCK_NOPE_DIM=nope_dim, + BLOCK_ROPE_DIM=rope_dim, num_warps=num_warps, num_stages=1, ) From aa012be27cedd2fda3487d155cccf2d242a32c42 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 02:56:46 +0000 Subject: [PATCH 061/114] add flashinfer mla decode --- .../attention/mla_flashinfer_backend.py | 234 +++++++++++++++ .../layer_infer/transformer_layer_infer.py | 22 +- .../context_flashattention_nopad_with_v.py | 278 ------------------ 3 files changed, 241 insertions(+), 293 deletions(-) create mode 100644 lightllm/common/basemodel/attention/mla_flashinfer_backend.py delete mode 100644 lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py diff --git a/lightllm/common/basemodel/attention/mla_flashinfer_backend.py b/lightllm/common/basemodel/attention/mla_flashinfer_backend.py new file mode 100644 index 0000000000..9bebe8ce82 --- /dev/null +++ b/lightllm/common/basemodel/attention/mla_flashinfer_backend.py @@ -0,0 +1,234 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ..triton_kernel.repack_kv_index import repack_kv_index +from typing import Tuple + + +class MlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + return + + def create_att_prefill_state(self, infer_state) -> "MlaFlashInferPrefillAttState": + return MlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaFlashInferDecodeAttState": + return MlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: MlaFlashInferAttBackend = self.backend + + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.qk_nope_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + return + + def copy_for_prefill_cuda_graph(self, new_state: "MlaFlashInferPrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFlashInferAttBackend = self.backend # for typing + k_nope, k_rope = k + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class MlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: MlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32).to(device) + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + assert self.decode_wrapper is None + + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + 1, + False, # causal + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + 1, + False, # causal + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + qk_rope_head_dim = 64 + + q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] + + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + assert att_control.mla_decode + + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim], + k[:, :, -qk_rope_head_dim:], + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 9347cf83fb..42560705c8 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -411,21 +411,13 @@ def _context_attention_kernel_with_CC( infer_state=infer_state, layer_weight=layer_weight, ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, + + o_tensor = infer_state.prefill_att_state.prefill_att( + q=q, + k=(k_nope, k_rope), + v=v, + att_control=AttControl(mla_prefill=True, mla_prefill_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, ) return o_tensor diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py deleted file mode 100644 index be06351823..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py +++ /dev/null @@ -1,278 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel_with_v( - Q_nope, - Q_rope, - K_nope, - K_rope, - V, - sm_scale, - B_Start_Loc, - B_Kv_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_q_bs, - stride_q_h, - stride_q_rope_bs, - stride_q_rope_h, - stride_k_bs, - stride_k_h, - stride_k_rope_bs, - stride_k_rope_h, - stride_vbs, - stride_vh, - stride_obs, - stride_oh, - b_prompt_cache_len, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_k_head = cur_head - - cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] - off_q_rope = ( - (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs - + cur_head * stride_q_rope_h - + offs_rope_d[None, :] - ) - off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] - off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] - off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] - - q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K_nope + off_k - k_rope_ptrs = K_rope + off_k_rope - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, - other=0.0, - ) - k_rope = tl.load( - k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk += tl.dot(q_rope, k_rope) - qk *= sm_scale - qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) - - # -- compute m_ij, p, l_ij - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - # -- update output accumulator -- - acc = acc * alpha[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < block_end_loc, - other=0.0, - ) - p = p.to(v.dtype) - acc = tl.dot(p, v, acc) - # update m_i and l_i - m_i = m_ij - - acc = acc / l_i[:, None] - # initialize pointers to output - off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - -@torch.no_grad() -def context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o, - b_start_loc, - b_kv_start_loc, - b_seq_len, - b_prompt_cache_len, - max_input_len, - softmax_scale, -): - - BLOCK = 128 if not is_tesla() else 64 - q_nope_dim = q_nope.shape[-1] - q_rope_dim = q_rope.shape[-1] - assert q_nope_dim == k_nope.shape[-1] - assert q_rope_dim == k_rope.shape[-1] - assert q_nope_dim in {16, 32, 64, 128, 256, 512} - assert q_rope_dim in {16, 32, 64, 128, 256} - assert q_nope_dim == v.shape[-1] - - if q_nope_dim >= 512: - BLOCK = 64 if not is_tesla() else 32 - else: - BLOCK = 128 if not is_tesla() else 64 - - if q_nope.dtype == torch.float32: - BLOCK = BLOCK // 4 - - sm_scale = softmax_scale * 1.4426950408889634 - batch, head = b_seq_len.shape[0], q_nope.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - num_warps = 4 if q_nope_dim <= 64 else 8 - - _fwd_kernel_with_v[grid]( - q_nope, - q_rope, - k_nope, - k_rope, - v, - sm_scale, - b_start_loc, - b_kv_start_loc, - b_seq_len, - o, - q_nope.stride(0), - q_nope.stride(1), - q_rope.stride(0), - q_rope.stride(1), - k_nope.stride(0), - k_nope.stride(1), - k_rope.stride(0), - k_rope.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - b_prompt_cache_len=b_prompt_cache_len, - BLOCK_M=BLOCK, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -if __name__ == "__main__": - import torch - import flashinfer - - Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64 - dtype = torch.bfloat16 - - k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda") - k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda") - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1) - v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda") - - max_input_len = Z * N_CTX - softmax_scale = 0.117 - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX - b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") - q_lens = b_seq_len - b_prompt_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - kv_start_loc = b_seq_len.cumsum(0) - b_seq_len - - q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") - q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda") - q = torch.cat([q_nope, q_rope], dim=-1) - - o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") - o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") - o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda") - - fn1 = lambda: context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o, - q_start_loc, - kv_start_loc, - b_seq_len, - b_prompt_cache_len, - max_input_len, - softmax_scale, - ) - - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_layout = "NHD" - batch_size = Z - q_indptr = q_starts - kv_indptr = kv_starts - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout) - wrapper.plan( - qo_indptr=q_indptr, - kv_indptr=kv_indptr, - num_qo_heads=H, - num_kv_heads=H, - head_dim_qk=D_HEAD + ROPE_HEAD, - head_dim_vo=D_HEAD, - q_data_type=dtype, - causal=True, - sm_scale=softmax_scale, - ) - fn2 = lambda: wrapper.run(q, k, v, out=o1) - - ms1 = triton.testing.do_bench(fn1) - ms2 = triton.testing.do_bench(fn2) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - print(ms1) - print(ms2) From 8ee933cbc90b74a3f6c8ea27a0e0e66499a6352c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 03:06:06 +0000 Subject: [PATCH 062/114] fix --- .../attention/mla_flashinfer_backend.py | 13 ++-- .../basemodel/attention/mla_triton_backend.py | 16 +++-- .../layer_infer/transformer_layer_infer.py | 65 ------------------- 3 files changed, 16 insertions(+), 78 deletions(-) diff --git a/lightllm/common/basemodel/attention/mla_flashinfer_backend.py b/lightllm/common/basemodel/attention/mla_flashinfer_backend.py index 9bebe8ce82..670652fb3f 100644 --- a/lightllm/common/basemodel/attention/mla_flashinfer_backend.py +++ b/lightllm/common/basemodel/attention/mla_flashinfer_backend.py @@ -189,7 +189,7 @@ def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): def decode_att( self, - q: torch.Tensor, + q: Tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, v: torch.Tensor, att_control: AttControl = AttControl(), @@ -200,6 +200,9 @@ def decode_att( and att_control.use_sliding_window is False and att_control.use_att_sink is False ) + + assert v is None + return self._mla_decode_att( q=q, k=k, @@ -210,17 +213,15 @@ def decode_att( def _mla_decode_att( self, - q: torch.Tensor, + q: Tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty, ): qk_rope_head_dim = 64 - - q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] - - o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + q_nope, q_rope = q + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) assert att_control.mla_decode self.decode_wrapper.run( diff --git a/lightllm/common/basemodel/attention/mla_triton_backend.py b/lightllm/common/basemodel/attention/mla_triton_backend.py index e735d9067e..c7aa64e5a6 100644 --- a/lightllm/common/basemodel/attention/mla_triton_backend.py +++ b/lightllm/common/basemodel/attention/mla_triton_backend.py @@ -90,20 +90,20 @@ def decode_att( and att_control.use_alibi is False ) assert v is None - q_nope, q_rope = q + return self._mla_decode_att( - q_nope=q_nope, - q_rope=q_rope, - kv=k, + q=q, + k=k, + v=v, att_control=att_control, alloc_func=alloc_func, ) def _mla_decode_att( self, - q_nope: torch.Tensor, - q_rope: torch.Tensor, - kv: torch.Tensor, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty, ): @@ -113,6 +113,8 @@ def _mla_decode_att( from ..triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding qk_rope_head_dim = 64 + q_nope, q_rope = q + kv = k out = gqa_token_decode_attention_flash_decoding( q_nope=q_nope, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 42560705c8..5b6bf31031 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -374,31 +374,6 @@ def _context_attention_flashattention_kernel_with_CC( o_tensor = tmp_output return o_tensor - def _context_attention_flashinfer_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - def _context_attention_kernel_with_CC( self, q: torch.Tensor, @@ -450,25 +425,6 @@ def _token_gqa_decode_attention_flashattention( ) return o_tensor - def _token_gqa_decode_attention_flashinfer( - self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) - - infer_state.decode_wrapper.run( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - out=o_tensor, - return_lse=False, - ) - return o_tensor - def _token_gqa_decode_attention_flashdecoding( self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -486,27 +442,6 @@ def _token_gqa_decode_attention_flashdecoding( return out - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank], - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :], - ) - return - - def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager): - destindex_copy_kv_fp8( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype), - ) - return - def _moe_ffn( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: From 6e7af99a8769fc0c6a8302fec7f42877b826b5b7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 04:27:35 +0000 Subject: [PATCH 063/114] fix --- .../basemodel/attention/mla_fa3_backend.py | 234 ++++++++++++++ .../layer_infer/transformer_layer_infer.py | 285 +++++------------- 2 files changed, 308 insertions(+), 211 deletions(-) create mode 100644 lightllm/common/basemodel/attention/mla_fa3_backend.py diff --git a/lightllm/common/basemodel/attention/mla_fa3_backend.py b/lightllm/common/basemodel/attention/mla_fa3_backend.py new file mode 100644 index 0000000000..f19f56e8f9 --- /dev/null +++ b/lightllm/common/basemodel/attention/mla_fa3_backend.py @@ -0,0 +1,234 @@ +import dataclasses +import torch +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING, Tuple +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.utils.sgl_utils import flash_attn_varlen_func + + +class MlaFa3AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. + """ + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "MlaFa3PrefillAttState": + return MlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": + return Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def copy_for_prefill_cuda_graph(self, new_state: "MlaFa3PrefillAttState"): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFa3AttBackend = self.backend # for typing + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + + o_tensor = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=softmax_scale, + causal=True, + return_softmax_lse=False, + ) + return o_tensor + + +@dataclasses.dataclass +class Fa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: MlaFa3AttBackend = self.backend + args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( + b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] + ) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): + pass + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + q_nope, q_rope = q + kv = k + qk_rope_head_dim = 64 + kv_lora_rank = kv.shape[-1] - qk_rope_head_dim + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) + k_descale, v_descale = None, None + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o_tensor diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 5b6bf31031..de750bb581 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,31 +1,19 @@ import os import torch -import torch.functional as F import torch.distributed as dist -import numpy as np import triton -from typing import Tuple from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel.attention.base_att import AttControl -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv -from lightllm.models.deepseek2.triton_kernel.repeat_rope import repeat_rope -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger -from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 logger = init_logger(__name__) @@ -86,34 +74,81 @@ def _bind_ffn(self): self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) self._tpsp_ffn = self._tpsp_ffn_tp - def _bind_attention(self): + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + k_nope, k_rope, v = self._decompress_kv( + infer_state=infer_state, + layer_weight=layer_weight, + ) - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) - elif get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self - ) - if self.enable_cc_method: - if get_env_start_args().enable_fa3: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self - ) - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self - ) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=q, + k=(k_nope, k_rope), + v=v, + att_control=AttControl(mla_prefill=True, mla_prefill_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return o_tensor + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.get_att_input_params() + + out = infer_state.decode_att_state.decode_att( + q=(q_nope, q_rope), + k=kv, + v=None, + att_control=AttControl(mla_decode=True, mla_decode_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return out + + def _decompress_kv( + self, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + ): + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + + total_token_num = infer_state.total_token_num + sampled_compressed_kv_nope = self.alloc_tensor( + [total_token_num, 1, layer_weight.kv_lora_rank], dtype=compressed_kv.dtype + ) + sampled_k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=compressed_kv.dtype) + sample_kv( + all_compressed_kv=compressed_kv, + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + sampled_k_rope=sampled_k_rope, + b_req_idx=infer_state.b_req_idx, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + b_seq_len=infer_state.b_seq_len, + b_kv_start_loc=infer_state.b1_kv_start_loc[:-1], + max_kv_seq_len=infer_state.max_kv_seq_len, + ) + # CC + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view( + total_token_num, layer_weight.kv_lora_rank + ).contiguous() + sampled_kv_nope = self.alloc_tensor( + [total_token_num, self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], + dtype=sampled_compressed_kv_nope.dtype, + ) + layer_weight.cc_kv_b_proj_.mm(sampled_compressed_kv_nope, out=sampled_kv_nope.view(total_token_num, -1)) + sampled_k_nope, sampled_v = torch.split(sampled_kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + return sampled_k_nope, sampled_k_rope, sampled_v def _get_qkv( self, @@ -270,178 +305,6 @@ def _tpsp_get_o( return o_tensor - def _decompress_kv( - self, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - ): - compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - - total_token_num = infer_state.total_token_num - sampled_compressed_kv_nope = self.alloc_tensor( - [total_token_num, 1, layer_weight.kv_lora_rank], dtype=compressed_kv.dtype - ) - sampled_k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=compressed_kv.dtype) - sample_kv( - all_compressed_kv=compressed_kv, - sampled_compressed_kv_nope=sampled_compressed_kv_nope, - sampled_k_rope=sampled_k_rope, - b_req_idx=infer_state.b_req_idx, - req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, - b_seq_len=infer_state.b_seq_len, - b_kv_start_loc=infer_state.b1_kv_start_loc[:-1], - max_kv_seq_len=infer_state.max_kv_seq_len, - ) - # CC - sampled_compressed_kv_nope = sampled_compressed_kv_nope.view( - total_token_num, layer_weight.kv_lora_rank - ).contiguous() - sampled_kv_nope = self.alloc_tensor( - [total_token_num, self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], - dtype=sampled_compressed_kv_nope.dtype, - ) - layer_weight.cc_kv_b_proj_.mm(sampled_compressed_kv_nope, out=sampled_kv_nope.view(total_token_num, -1)) - sampled_k_nope, sampled_v = torch.split(sampled_kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - return sampled_k_nope, sampled_k_rope, sampled_v - - # Adapted from: - # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962 - def _context_attention_flashattention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashAttentionStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - skip_sample=True, - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - o_tensor, lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_q, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.max_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - return_softmax_lse=True, - ) - if infer_state.has_prefix_kv: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.prefix_total_token_num, - infer_state.b_ready_cache_len, - infer_state.prefix_k_max_len, - infer_state.cu_seqlens_prefix_k, - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - prefix_output, prefix_lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_prefix_k, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.prefix_k_max_len, - softmax_scale=self.softmax_scale, - causal=False, - return_softmax_lse=True, - ) - lse = torch.transpose(lse, 0, 1).contiguous() - prefix_lse = torch.transpose(prefix_lse, 0, 1).contiguous() - tmp_output = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) - if out is None - else out - ) - tmp_lse = torch.empty_like(lse) - merge_state_v2(prefix_output, prefix_lse, o_tensor, lse, tmp_output, tmp_lse) - o_tensor = tmp_output - return o_tensor - - def _context_attention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - infer_state=infer_state, - layer_weight=layer_weight, - ) - - o_tensor = infer_state.prefill_att_state.prefill_att( - q=q, - k=(k_nope, k_rope), - v=v, - att_control=AttControl(mla_prefill=True, mla_prefill_dict={"softmax_scale": self.softmax_scale}), - alloc_func=self.alloc_tensor, - ) - return o_tensor - - def _token_gqa_decode_attention_flashattention( - self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_att_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashdecoding( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.get_att_input_params() - - out = infer_state.decode_att_state.decode_att( - q=(q_nope, q_rope), - k=kv, - v=None, - att_control=AttControl(mla_decode=True, mla_decode_dict={"softmax_scale": self.softmax_scale}), - alloc_func=self.alloc_tensor, - ) - - return out - def _moe_ffn( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: From 24588d2174568ea8271c562b042d9cc1be1444ed Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 04:45:41 +0000 Subject: [PATCH 064/114] fix --- lightllm/common/basemodel/attention/fa3/__init__.py | 0 .../basemodel/attention/{fa3_backend.py => fa3/fp.py} | 0 .../attention/{fp8_fa3_backend.py => fa3/fp8.py} | 0 .../attention/{mla_fa3_backend.py => fa3/mla.py} | 8 ++++---- .../common/basemodel/attention/flashinfer/__init__.py | 0 .../attention/{flashinfer_backend.py => flashinfer/fp.py} | 0 .../{fp8_flashinfer_backend.py => flashinfer/fp8.py} | 0 .../{mla_flashinfer_backend.py => flashinfer/mla.py} | 0 lightllm/common/basemodel/attention/triton/__init__.py | 0 .../attention/{triton_backend.py => triton/fp.py} | 0 .../{int4kv_triton_backend.py => triton/int4kv.py} | 0 .../{int8kv_triton_backend.py => triton/int8kv.py} | 0 .../attention/{mla_triton_backend.py => triton/mla.py} | 0 13 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 lightllm/common/basemodel/attention/fa3/__init__.py rename lightllm/common/basemodel/attention/{fa3_backend.py => fa3/fp.py} (100%) rename lightllm/common/basemodel/attention/{fp8_fa3_backend.py => fa3/fp8.py} (100%) rename lightllm/common/basemodel/attention/{mla_fa3_backend.py => fa3/mla.py} (96%) create mode 100644 lightllm/common/basemodel/attention/flashinfer/__init__.py rename lightllm/common/basemodel/attention/{flashinfer_backend.py => flashinfer/fp.py} (100%) rename lightllm/common/basemodel/attention/{fp8_flashinfer_backend.py => flashinfer/fp8.py} (100%) rename lightllm/common/basemodel/attention/{mla_flashinfer_backend.py => flashinfer/mla.py} (100%) create mode 100644 lightllm/common/basemodel/attention/triton/__init__.py rename lightllm/common/basemodel/attention/{triton_backend.py => triton/fp.py} (100%) rename lightllm/common/basemodel/attention/{int4kv_triton_backend.py => triton/int4kv.py} (100%) rename lightllm/common/basemodel/attention/{int8kv_triton_backend.py => triton/int8kv.py} (100%) rename lightllm/common/basemodel/attention/{mla_triton_backend.py => triton/mla.py} (100%) diff --git a/lightllm/common/basemodel/attention/fa3/__init__.py b/lightllm/common/basemodel/attention/fa3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/fa3_backend.py b/lightllm/common/basemodel/attention/fa3/fp.py similarity index 100% rename from lightllm/common/basemodel/attention/fa3_backend.py rename to lightllm/common/basemodel/attention/fa3/fp.py diff --git a/lightllm/common/basemodel/attention/fp8_fa3_backend.py b/lightllm/common/basemodel/attention/fa3/fp8.py similarity index 100% rename from lightllm/common/basemodel/attention/fp8_fa3_backend.py rename to lightllm/common/basemodel/attention/fa3/fp8.py diff --git a/lightllm/common/basemodel/attention/mla_fa3_backend.py b/lightllm/common/basemodel/attention/fa3/mla.py similarity index 96% rename from lightllm/common/basemodel/attention/mla_fa3_backend.py rename to lightllm/common/basemodel/attention/fa3/mla.py index f19f56e8f9..17430bcaed 100644 --- a/lightllm/common/basemodel/attention/mla_fa3_backend.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -34,8 +34,8 @@ def get_page_table_buffer(self): def create_att_prefill_state(self, infer_state) -> "MlaFa3PrefillAttState": return MlaFa3PrefillAttState(backend=self, infer_state=infer_state) - def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": - return Fa3DecodeAttState(backend=self, infer_state=infer_state) + def create_att_decode_state(self, infer_state) -> "MlaFa3DecodeAttState": + return MlaFa3DecodeAttState(backend=self, infer_state=infer_state) @dataclasses.dataclass @@ -100,7 +100,7 @@ def _mla_prefill_att( @dataclasses.dataclass -class Fa3DecodeAttState(BaseDecodeAttState): +class MlaFa3DecodeAttState(BaseDecodeAttState): cu_seqlens_q: torch.Tensor = None cu_seqlens_k: torch.Tensor = None page_table: torch.Tensor = None @@ -169,7 +169,7 @@ def init_state(self): self.decode_max_q_seq_len = 1 return - def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): + def copy_for_decode_cuda_graph(self, new_state: "MlaFa3DecodeAttState"): pass def decode_att( diff --git a/lightllm/common/basemodel/attention/flashinfer/__init__.py b/lightllm/common/basemodel/attention/flashinfer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer/fp.py similarity index 100% rename from lightllm/common/basemodel/attention/flashinfer_backend.py rename to lightllm/common/basemodel/attention/flashinfer/fp.py diff --git a/lightllm/common/basemodel/attention/fp8_flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py similarity index 100% rename from lightllm/common/basemodel/attention/fp8_flashinfer_backend.py rename to lightllm/common/basemodel/attention/flashinfer/fp8.py diff --git a/lightllm/common/basemodel/attention/mla_flashinfer_backend.py b/lightllm/common/basemodel/attention/flashinfer/mla.py similarity index 100% rename from lightllm/common/basemodel/attention/mla_flashinfer_backend.py rename to lightllm/common/basemodel/attention/flashinfer/mla.py diff --git a/lightllm/common/basemodel/attention/triton/__init__.py b/lightllm/common/basemodel/attention/triton/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/triton_backend.py b/lightllm/common/basemodel/attention/triton/fp.py similarity index 100% rename from lightllm/common/basemodel/attention/triton_backend.py rename to lightllm/common/basemodel/attention/triton/fp.py diff --git a/lightllm/common/basemodel/attention/int4kv_triton_backend.py b/lightllm/common/basemodel/attention/triton/int4kv.py similarity index 100% rename from lightllm/common/basemodel/attention/int4kv_triton_backend.py rename to lightllm/common/basemodel/attention/triton/int4kv.py diff --git a/lightllm/common/basemodel/attention/int8kv_triton_backend.py b/lightllm/common/basemodel/attention/triton/int8kv.py similarity index 100% rename from lightllm/common/basemodel/attention/int8kv_triton_backend.py rename to lightllm/common/basemodel/attention/triton/int8kv.py diff --git a/lightllm/common/basemodel/attention/mla_triton_backend.py b/lightllm/common/basemodel/attention/triton/mla.py similarity index 100% rename from lightllm/common/basemodel/attention/mla_triton_backend.py rename to lightllm/common/basemodel/attention/triton/mla.py From 3fc80845132ed9960ddc58634730a9edf22f0e2e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 04:52:27 +0000 Subject: [PATCH 065/114] fix all --- lightllm/common/basemodel/attention/fa3/fp.py | 2 +- .../common/basemodel/attention/fa3/fp8.py | 4 ++-- .../common/basemodel/attention/fa3/mla.py | 2 +- .../basemodel/attention/flashinfer/fp.py | 4 ++-- .../basemodel/attention/flashinfer/fp8.py | 4 ++-- .../basemodel/attention/flashinfer/mla.py | 4 ++-- .../common/basemodel/attention/triton/fp.py | 20 +++++++++---------- .../basemodel/attention/triton/int4kv.py | 8 ++++---- .../basemodel/attention/triton/int8kv.py | 10 +++++----- .../common/basemodel/attention/triton/mla.py | 6 +++--- 10 files changed, 32 insertions(+), 32 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 794c0dff8a..4c0d2cca14 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index b4aa55653f..53974730ff 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache @@ -10,7 +10,7 @@ from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union -from .fa3_backend import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState +from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState if HAS_VLLM: scaled_fp8_quant = vllm_ops.scaled_fp8_quant diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 17430bcaed..187a8ae034 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING, Tuple from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 024e1e714d..af3ba7d015 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -1,8 +1,8 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id -from ..triton_kernel.repack_kv_index import repack_kv_index +from ...triton_kernel.repack_kv_index import repack_kv_index class FlashInferAttBackend(BaseAttBackend): diff --git a/lightllm/common/basemodel/attention/flashinfer/fp8.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py index c334cf6193..ee20f40e8e 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp8.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp8.py @@ -1,7 +1,7 @@ import dataclasses import torch -from .base_att import AttControl -from .flashinfer_backend import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState +from ..base_att import AttControl +from .fp import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState class Fp8FlashInferAttBackend(FlashInferAttBackend): diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 670652fb3f..7ba50eeb85 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -1,8 +1,8 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id -from ..triton_kernel.repack_kv_index import repack_kv_index +from ...triton_kernel.repack_kv_index import repack_kv_index from typing import Tuple diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index ec19f16997..0b65cf8bb6 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional @@ -45,7 +45,7 @@ def _alibi_prefill_att( ): out = alloc_func(q.shape, q.dtype) - from ..triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd + from ...triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd context_attention_fwd( q, @@ -63,7 +63,7 @@ def _alibi_prefill_att( return out def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): - from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd out = alloc_func(q.shape, q.dtype) context_attention_fwd( @@ -119,7 +119,7 @@ def _alibi_decode_att( att_control: AttControl, alloc_func=torch.empty, ): - from ..triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd + from ...triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd out = alloc_func(q.shape, q.dtype) token_attention_fwd( @@ -145,7 +145,7 @@ def _normal_decode_flash_decoding_att( v: torch.Tensor, alloc_func=torch.empty, ): - from ..triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( + from ...triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( token_decode_attention_flash_decoding, ) @@ -168,7 +168,7 @@ def _normal_decode_gqa_flash_decoding_att( v: torch.Tensor, alloc_func=torch.empty, ): - from ..triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( gqa_token_decode_attention_flash_decoding, ) @@ -193,7 +193,7 @@ def _normal_decode_gqa_flash_decoding_att_vsm( alloc_func=torch.empty, ): # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 - from ..triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_vsm import ( + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_vsm import ( gqa_token_decode_attention_flash_decoding_vsm, ) @@ -218,7 +218,7 @@ def _normal_decode_gqa_att( alloc_func=torch.empty, ): # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 - from ..triton_kernel.att.decode_att.gqa.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + from ...triton_kernel.att.decode_att.gqa.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd out = alloc_func(q.shape, q.dtype) @@ -248,7 +248,7 @@ def _normal_decode_stage3_att( calcu_shape1 = (batch_size, q_head_num, head_dim) att_m_tensor = alloc_func((q_head_num, total_token_num), torch.float32) - from ..triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_nopad_att1 import token_att_fwd + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_nopad_att1 import token_att_fwd token_att_fwd( q.view(calcu_shape1), @@ -262,7 +262,7 @@ def _normal_decode_stage3_att( ) o_tensor = alloc_func(q.shape, q.dtype) - from ..triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_softmax_and_reducev import ( + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_softmax_and_reducev import ( token_softmax_reducev_fwd, ) diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py index 2b0a075480..14aae90e63 100644 --- a/lightllm/common/basemodel/attention/triton/int4kv.py +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -1,7 +1,7 @@ import dataclasses import torch from lightllm.utils.envs_utils import get_env_start_args -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, Tuple @@ -82,7 +82,7 @@ def _groupsize_quant_prefill_att( max_kv_seq_len = self.infer_state.max_kv_seq_len - from ..triton_kernel.kv_copy.ppl_int4kv_copy_kv import dequantize_int4kv + from ...triton_kernel.kv_copy.ppl_int4kv_copy_kv import dequantize_int4kv dequantize_int4kv( k=k, @@ -99,7 +99,7 @@ def _groupsize_quant_prefill_att( quant_group_size=self.backend.quant_group_size, ) - from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv context_attention_fwd_contiguous_kv( q=q, @@ -157,7 +157,7 @@ def ppl_int4kv_decode_att( v_scale: torch.Tensor, alloc_func=torch.empty, ) -> torch.Tensor: - from ..triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( + from ...triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index bf0ffdfdc9..e7bb34e2f9 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -1,7 +1,7 @@ import dataclasses import torch from lightllm.utils.envs_utils import get_env_start_args -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, Tuple from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel @@ -83,7 +83,7 @@ def _groupsize_quant_prefill_att( max_kv_seq_len = self.infer_state.max_kv_seq_len - from ..triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv + from ...triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv dequantize_int8kv( k=k, @@ -100,7 +100,7 @@ def _groupsize_quant_prefill_att( quant_group_size=self.backend.quant_group_size, ) - from ..triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv context_attention_fwd_contiguous_kv( q=q, @@ -161,7 +161,7 @@ def diverse_decode_att( alloc_func=torch.empty, ) -> torch.Tensor: - from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding, ) @@ -184,7 +184,7 @@ def ppl_mha_int8kv_decode_att( v_scale: torch.Tensor, alloc_func=torch.empty, ) -> torch.Tensor: - from ..triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( + from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index c7aa64e5a6..e0aced276b 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -1,6 +1,6 @@ import dataclasses import torch -from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Tuple @@ -43,7 +43,7 @@ def _mla_prefill_att( att_control: AttControl, alloc_func=torch.empty, ): - from ..triton_kernel.mla_att.prefill_att import context_attention_fwd_with_v + from ...triton_kernel.mla_att.prefill_att import context_attention_fwd_with_v qk_rope_head_dim = 64 q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] @@ -110,7 +110,7 @@ def _mla_decode_att( assert att_control.mla_decode softmax_scale = att_control.mla_prefill_dict["softmax_scale"] - from ..triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding + from ...triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding qk_rope_head_dim = 64 q_nope, q_rope = q From e78b34649f0aa88db02c9525eef8b89ec2c8b49b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:11:44 +0000 Subject: [PATCH 066/114] fix --- .../basemodel/attention/create_utils.py | 71 ++++++++++++------- lightllm/server/api_cli.py | 14 ++-- lightllm/server/core/objs/start_args_type.py | 10 ++- 3 files changed, 60 insertions(+), 35 deletions(-) diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 1e7f0e9ea5..39e32ac635 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,15 +1,18 @@ from lightllm.utils.envs_utils import get_env_start_args from .base_att import BaseAttBackend -from .triton_backend import TritonAttBackend -from .int4kv_triton_backend import Int4kvTritonAttBackend -from .int8kv_triton_backend import Int8kvTritonAttBackend -from .fa3_backend import Fa3AttBackend -from .fp8_fa3_backend import Fp8Fa3AttBackend -from .flashinfer_backend import FlashInferAttBackend -from .fp8_flashinfer_backend import Fp8FlashInferAttBackend - -backend_dict = { - None: { +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from .flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend + +data_type_to_backend = { + "None": { "triton": TritonAttBackend, "fa3": Fa3AttBackend, "flashinfer": FlashInferAttBackend, @@ -26,36 +29,52 @@ }, } +mla_data_type_to_backend = { + "None": { + "triton": MlaTritonAttBackend, + "fa3": MlaFa3AttBackend, + "flashinfer": MlaFlashInferAttBackend, + }, +} + -def get_prefill_att_backend_class() -> BaseAttBackend: +def get_prefill_att_backend_class(index=0) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type - if args.llm_prefill_att_backend is not None: - return backend_dict[llm_dtype][args.llm_prefill_att_backend] + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] else: # 根据环境自动选择最好的 raise NotImplementedError(f"error") -def get_decode_att_backend_class() -> BaseAttBackend: +def get_decode_att_backend_class(index=0) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type - if args.llm_decode_att_backend is not None: - return backend_dict[llm_dtype][args.llm_decode_att_backend] + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] else: # 根据环境自动选择最好的 raise NotImplementedError(f"error") -def get_mla_prefill_att_backend_class() -> BaseAttBackend: - # args = get_env_start_args() - # llm_dtype = args.llm_kv_type - # 根据环境自动选择最好的 - raise NotImplementedError(f"error") +def get_mla_prefill_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") -def get_mla_decode_att_backend_class() -> BaseAttBackend: - # args = get_env_start_args() - # llm_dtype = args.llm_kv_type - # 根据环境自动选择最好的 - raise NotImplementedError(f"error") +def get_mla_decode_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 8b16cb5196..10aef09233 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -339,22 +339,24 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--llm_prefill_att_backend", type=str, - choices=[None, "triton", "fa3", "flashinfer"], - default=None, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["None"], help="""prefill attention kernel used in llm""", ) parser.add_argument( "--llm_decode_att_backend", type=str, - choices=[None, "triton", "fa3", "flashinfer"], - default=None, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["None"], help="""decode attention kernel used in llm""", ) parser.add_argument( "--llm_kv_type", type=str, - choices=[None, "int8kv", "int4kv", "fp8kv"], - default=None, + choices=["None", "int8kv", "int4kv", "fp8kv"], + default="None", help="""kv type used in llm, None for dtype that llm used in config.json""", ) parser.add_argument( diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 8ce31dd657..184e4d7a90 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -118,9 +118,13 @@ class StartArgs: vit_quant_cfg: Optional[str] = field(default=None) enable_flashinfer_prefill: bool = field(default=False) enable_flashinfer_decode: bool = field(default=False) - llm_prefill_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) - llm_decode_att_backend: str = field(default=None, metadata={"choices": [None, "triton", "fa3", "flashinfer"]}) - llm_kv_type: str = field(default=None, metadata={"choices": [None, "int8kv", "int4kv", "fp8kv"]}) + llm_prefill_att_backend: List[str] = field( + default=["None"], metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_decode_att_backend: List[str] = field( + default=["None"], metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( From f6d85080734c15b2a86dd4a2c83abb50b5467d5e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:21:04 +0000 Subject: [PATCH 067/114] fix --- lightllm/common/basemodel/basemodel.py | 22 ++++++++++++++++++++-- lightllm/common/basemodel/infer_struct.py | 12 ++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index f1d41af88f..9919c390f2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -123,9 +123,13 @@ def __init__(self, kvargs): self._wait_other_modules_ready() self._init_att_backend() + self._init_att_backend1() logger.info(f"use prefill att backend: {self.prefill_att_backend.__class__.__name__}") logger.info(f"use decode att backend: {self.decode_att_backend.__class__.__name__}") + if self.prefill_att_backend1 is not None: + logger.info(f"use prefill att backend1: {self.prefill_att_backend1.__class__.__name__}") + logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}") self._autotune_warmup() self._init_padded_req() @@ -247,8 +251,14 @@ def _init_some_value(self): return def _init_att_backend(self): - self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class()(model=self) - self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class()(model=self) + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0)(model=self) + return + + def _init_att_backend1(self): + # self.prefill_att_backend1 是给后续有模型支持不同层用不同的att模块时,保留的扩展。 + self.prefill_att_backend1: BaseAttBackend = None + self.decode_att_backend1: BaseAttBackend = None return def _init_cudagraph(self): @@ -326,8 +336,16 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) if infer_state.is_prefill: infer_state.prefill_att_state = self.prefill_att_backend.create_att_prefill_state(infer_state=infer_state) + if self.prefill_att_backend1 is not None: + infer_state.prefill_att_state1 = self.prefill_att_backend1.create_att_prefill_state( + infer_state=infer_state + ) else: infer_state.decode_att_state = self.decode_att_backend.create_att_decode_state(infer_state=infer_state) + if self.decode_att_backend1 is not None: + infer_state.decode_att_state1 = self.decode_att_backend1.create_att_decode_state( + infer_state=infer_state + ) return infer_state diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 8951c63f1f..cd826414fd 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -24,6 +24,10 @@ def __init__(self): self.prefill_att_state: BasePrefillAttState = None self.decode_att_state: BaseDecodeAttState = None + # 保留的扩展, 支持线性att与标准att混合使用时使用 + self.prefill_att_state1: BasePrefillAttState = None + self.decode_att_state1: BaseDecodeAttState = None + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None @@ -123,8 +127,12 @@ def init_some_extra_state(self, model): def init_att_state(self): if self.is_prefill: self.prefill_att_state.init_state() + if self.prefill_att_state1 is not None: + self.prefill_att_state1.init_state() else: self.decode_att_state.init_state() + if self.decode_att_state1 is not None: + self.decode_att_state1.init_state() def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in vars(new_infer_state).items(): @@ -132,6 +140,10 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): attr_ = getattr(self, attr_name, None) if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): attr_.copy_(attr_value, non_blocking=True) + + self.decode_att_state.copy_for_decode_cuda_graph() + if self.decode_att_state1 is not None: + self.decode_att_state1.copy_for_decode_cuda_graph() return def prefill_dp_balance(self, input_ids: torch.Tensor): From 4aa45e15e9ce8f7e29c49b4069d55cef12615357 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:22:43 +0000 Subject: [PATCH 068/114] fix --- lightllm/models/bloom/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 451c16fc1a..80ea4e058a 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -5,7 +5,7 @@ from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel -from lightllm.common.basemodel.attention import TritonAttBackend +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend @ModelRegistry("bloom") From 7b24aceb6b8e11fb1fa893bd965c66b5e6f971e5 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:24:56 +0000 Subject: [PATCH 069/114] fix --- lightllm/server/core/objs/start_args_type.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 184e4d7a90..4ca39de9a9 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -119,10 +119,10 @@ class StartArgs: enable_flashinfer_prefill: bool = field(default=False) enable_flashinfer_decode: bool = field(default=False) llm_prefill_att_backend: List[str] = field( - default=["None"], metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=["None"], metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} ) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) From e129535c17c2b9ba6de06af4d02c50c541c966df Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:25:45 +0000 Subject: [PATCH 070/114] fix --- lightllm/common/basemodel/attention/__init__.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index b8384e85ab..80df545498 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -1,11 +1,14 @@ from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl -from .triton_backend import TritonAttBackend, TritonPrefillAttState, TritonDecodeAttState -from .int4kv_triton_backend import Int4kvTritonAttBackend -from .int8kv_triton_backend import Int8kvTritonAttBackend -from .fa3_backend import Fa3AttBackend -from .fp8_fa3_backend import Fp8Fa3AttBackend -from .flashinfer_backend import FlashInferAttBackend -from .fp8_flashinfer_backend import Fp8FlashInferAttBackend +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from .flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend from .create_utils import ( get_prefill_att_backend_class, From f4fc09eda20bfef9ab8f07f35d02fa8bd6099436 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:29:25 +0000 Subject: [PATCH 071/114] fix --- lightllm/common/kv_cache_mem_manager/mem_utils.py | 2 +- lightllm/models/bloom/model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 2960ac6d36..1ff58b89a0 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -34,7 +34,7 @@ def select_mem_manager_class(): memory_manager_class = PPLINT4KVMemoryManager elif get_env_start_args().llm_kv_type == "fp8kv": memory_manager_class = ExportCalibrationMemoryManager - elif get_env_start_args().llm_kv_type is None: + elif get_env_start_args().llm_kv_type == "None": memory_manager_class = MemoryManager logger.info(f"Model kv cache using mem_manager class: {memory_manager_class}") diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 80ea4e058a..925620bf96 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -38,6 +38,6 @@ def _reset_num_key_value_heads(self): return def _init_att_backend(self): - self.prefill_att_backend = TritonAttBackend() - self.decode_att_backend = TritonAttBackend() + self.prefill_att_backend = TritonAttBackend(self) + self.decode_att_backend = TritonAttBackend(self) return From 51b9b93cf80208c465824f7bf076921f2eb7b5ea Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:34:34 +0000 Subject: [PATCH 072/114] fix --- lightllm/common/basemodel/basemodel.py | 4 --- lightllm/models/deepseek2/model.py | 45 +++----------------------- lightllm/models/llama/model.py | 3 -- lightllm/models/mistral/model.py | 3 -- lightllm/models/qwen2_vl/model.py | 3 -- lightllm/models/qwen3_vl/model.py | 3 -- lightllm/models/qwen3_vl_moe/model.py | 3 -- 7 files changed, 5 insertions(+), 59 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 9919c390f2..6513e25db9 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -118,7 +118,6 @@ def __init__(self, kvargs): self._init_infer_layer() self._init_some_value() self._init_custom() - self._init_inferstate_cls() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -156,9 +155,6 @@ def _init_config(self): self.config["vocab_size"] = self.finetune_config.vocab_size return - def _init_inferstate_cls(self): - pass - @final def _verify_must(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e4ce7c8269..95dd4fe519 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,49 +6,16 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id - +from lightllm.common.basemodel.attention import get_mla_decode_att_backend_class, get_mla_prefill_att_backend_class logger = init_logger(__name__) -class DeepSeek2FlashInferStateExtraInfo: - def __init__(self, model): - num_heads = model.config["num_attention_heads"] - self.tp_q_head_num = num_heads // get_dp_world_size() - self.qk_nope_head_dim = model.qk_nope_head_dim - self.qk_rope_head_dim = model.qk_rope_head_dim - self.kv_lora_rank = model.kv_lora_rank - self.q_data_type = model.data_type - self.kv_data_type = model.data_type - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - if model.config["rope_scaling"] is not None: - rope_scaling = model.config["rope_scaling"] - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) - scaling_factor = rope_scaling["factor"] - if mscale_all_dim: - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - @ModelRegistry(["deepseek_v2", "deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class @@ -67,12 +34,10 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = Deepseek2FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = Deepseek2FlashInferStateInfo - self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self) + def _init_att_backend(self): + self.prefill_att_backend = get_mla_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_mla_decode_att_backend_class(index=0)(model=self) + return def _init_some_value(self): super()._init_some_value() diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 438e2f9575..6b616ef29b 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -70,9 +70,6 @@ def _init_mem_manager(self): ) return - def _init_inferstate_cls(self): - pass - def _init_custom(self): """ 模型特殊的一些初始化 diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index 3dc19ea799..d2bfeaa952 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -42,9 +42,6 @@ def _init_custom(self): self._init_to_get_rotary() return - def _init_inferstate_cls(self): - pass - def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 61dd06773f..dd4181fbfb 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -95,9 +95,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) diff --git a/lightllm/models/qwen3_vl/model.py b/lightllm/models/qwen3_vl/model.py index 0d8a81f671..74aa33e3c0 100644 --- a/lightllm/models/qwen3_vl/model.py +++ b/lightllm/models/qwen3_vl/model.py @@ -37,9 +37,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/qwen3_vl_moe/model.py b/lightllm/models/qwen3_vl_moe/model.py index b11f22fdb7..cc1201de2c 100644 --- a/lightllm/models/qwen3_vl_moe/model.py +++ b/lightllm/models/qwen3_vl_moe/model.py @@ -25,9 +25,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) From f76dfb17bad7d776061a6a10d093f2d19aae8df6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:35:49 +0000 Subject: [PATCH 073/114] fix --- .../deepseek2/flashattention_infer_struct.py | 65 ----------- .../models/deepseek2/flashinfer_struct.py | 106 ------------------ lightllm/models/deepseek2/model.py | 2 - 3 files changed, 173 deletions(-) delete mode 100644 lightllm/models/deepseek2/flashattention_infer_struct.py delete mode 100644 lightllm/models/deepseek2/flashinfer_struct.py diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py deleted file mode 100644 index 72ba8a43b1..0000000000 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - args_mtp_step = get_env_start_args().mtp_step - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - self.has_prefix_kv = self.max_cache_len > 0 - if self.has_prefix_kv: - self.cu_seqlens_prefix_k = torch.nn.functional.pad( - torch.cumsum(self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0) - ) - self.prefix_k_max_len = self.max_cache_len - self.prefix_total_token_num = self.prefix_total_token_num - else: - # Meta information of flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - max_seq_len_k = self.max_kv_seq_len - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].view(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty((att_batch_size, self.max_len_in_batch), dtype=torch.int32).to( - self.input_ids.device - ) - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py deleted file mode 100644 index db6386f797..0000000000 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(self.input_ids.device) - if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - self.flashinfer_extra_state.workspace_buffer, - use_cuda_graph=True, - qo_indptr=self.q_indptr, - kv_indices=self.kv_indices, - kv_indptr=self.kv_starts, - kv_len_arr=self.b_seq_len, - ) - self.decode_wrapper.plan( - self.q_indptr, - self.kv_starts, - self.kv_indices, - self.b_seq_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.kv_lora_rank, - self.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - self.flashinfer_extra_state.softmax_scale, - self.flashinfer_extra_state.q_data_type, - self.flashinfer_extra_state.kv_data_type, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_kv_start_loc.int() - if self.prefill_wrapper is None: - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, "NHD" - ) - self.prefill_wrapper.plan( - qo_indptr=q_starts, - kv_indptr=kv_starts, - num_qo_heads=self.flashinfer_extra_state.tp_q_head_num, - num_kv_heads=self.flashinfer_extra_state.tp_q_head_num, - head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim - + self.flashinfer_extra_state.qk_rope_head_dim, - head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim, - q_data_type=self.flashinfer_extra_state.q_data_type, - causal=True, - sm_scale=self.flashinfer_extra_state.softmax_scale, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.q_indptr, - new_infer_state.kv_starts, - new_infer_state.kv_indices, - new_infer_state.b_seq_len, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.kv_lora_rank, - new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - new_infer_state.flashinfer_extra_state.softmax_scale, - new_infer_state.flashinfer_extra_state.q_data_type, - new_infer_state.flashinfer_extra_state.kv_data_type, - ) - return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 95dd4fe519..f9fe06a058 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -4,8 +4,6 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger From e2018d68166836c343ca013f42324932286edfce Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:38:42 +0000 Subject: [PATCH 074/114] fix --- lightllm/common/basemodel/attention/base_att.py | 4 ---- lightllm/common/basemodel/attention/fa3/fp.py | 3 --- lightllm/common/basemodel/attention/fa3/fp8.py | 3 --- lightllm/common/basemodel/attention/fa3/mla.py | 3 --- lightllm/common/basemodel/attention/flashinfer/fp.py | 3 --- lightllm/common/basemodel/attention/flashinfer/mla.py | 3 --- lightllm/common/basemodel/attention/triton/fp.py | 3 --- lightllm/common/basemodel/attention/triton/int4kv.py | 3 --- lightllm/common/basemodel/attention/triton/int8kv.py | 3 --- lightllm/common/basemodel/attention/triton/mla.py | 3 --- 10 files changed, 31 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 167dd16011..c15a254567 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -77,10 +77,6 @@ class BasePrefillAttState(ABC): def init_state(self): pass - @abstractmethod - def copy_for_prefill_cuda_graph(self, new_state: "BasePrefillAttState"): - pass - @abstractmethod def prefill_att( self, diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 4c0d2cca14..7058cdb394 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -57,9 +57,6 @@ def init_state(self): ] ) - def copy_for_prefill_cuda_graph(self, new_state: "Fa3PrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 53974730ff..41dcde125a 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -67,9 +67,6 @@ def init_state(self): ) ) - def copy_for_prefill_cuda_graph(self, new_state: "Fp8Fa3PrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 187a8ae034..ed6182fe43 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -47,9 +47,6 @@ def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() - def copy_for_prefill_cuda_graph(self, new_state: "MlaFa3PrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index af3ba7d015..ed3adf4aca 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -84,9 +84,6 @@ def init_state(self): kv_data_type=self.backend.kv_data_type, ) - def copy_for_prefill_cuda_graph(self, new_state: "FlashInferPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 7ba50eeb85..75883dd124 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -74,9 +74,6 @@ def init_state(self): ) return - def copy_for_prefill_cuda_graph(self, new_state: "MlaFlashInferPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index 0b65cf8bb6..b427a9d544 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -17,9 +17,6 @@ class TritonPrefillAttState(BasePrefillAttState): def init_state(self): pass - def copy_for_prefill_cuda_graph(self, new_state: "TritonPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py index 14aae90e63..74b595802e 100644 --- a/lightllm/common/basemodel/attention/triton/int4kv.py +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -29,9 +29,6 @@ def init_state(self): - self.infer_state.b_seq_len ) - def copy_for_prefill_cuda_graph(self, new_state: "Int4kvTritonPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index e7bb34e2f9..492f51cccb 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -30,9 +30,6 @@ def init_state(self): - self.infer_state.b_seq_len ) - def copy_for_prefill_cuda_graph(self, new_state: "Int8kvTritonPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index e0aced276b..b628d7239e 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -17,9 +17,6 @@ class MlaTritonPrefillAttState(BasePrefillAttState): def init_state(self): pass - def copy_for_prefill_cuda_graph(self, new_state: "MlaTritonPrefillAttState"): - pass - def prefill_att( self, q: torch.Tensor, From df348ee16af5ce09be5487c7912b2b197ed82e2b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 05:57:41 +0000 Subject: [PATCH 075/114] fix --- lightllm/common/basemodel/attention/flashinfer/fp.py | 2 +- lightllm/common/basemodel/attention/flashinfer/mla.py | 4 ++-- lightllm/common/basemodel/attention/triton/fp.py | 10 +++++----- lightllm/common/basemodel/attention/triton/int4kv.py | 2 +- lightllm/common/basemodel/attention/triton/int8kv.py | 2 +- lightllm/common/basemodel/attention/triton/mla.py | 2 +- lightllm/common/basemodel/infer_struct.py | 10 +++++++--- 7 files changed, 18 insertions(+), 14 deletions(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index ed3adf4aca..21f30dd6f4 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -149,7 +149,7 @@ def init_state(self): self.infer_state.req_manager.req_to_token_indexs, self.infer_state.b_req_idx, self.infer_state.b_seq_len, - self.infer_state.b_start_loc, + self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, ) diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 75883dd124..8786a86057 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -122,7 +122,7 @@ def init_state(self): self.kv_starts = self.infer_state.b1_cu_kv_seq_len - self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32).to(device) + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ : batch_size * self.backend.max_seq_length @@ -138,7 +138,7 @@ def init_state(self): self.infer_state.req_manager.req_to_token_indexs, self.infer_state.b_req_idx, self.infer_state.b_seq_len, - self.infer_state.b_start_loc, + self.infer_state.b_kv_start_loc, self.infer_state.max_kv_seq_len, self.kv_indices, ) diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index b427a9d544..da9e5205d4 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -51,7 +51,7 @@ def _alibi_prefill_att( out, self.infer_state.b_req_idx, att_control.tp_alibi, - self.infer_state.b_start_loc, + self.infer_state.b_q_start_loc, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, self.infer_state.max_len_in_batch, @@ -69,7 +69,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, v, out, self.infer_state.b_req_idx, - self.infer_state.b_start_loc, + self.infer_state.b_q_start_loc, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, self.infer_state.max_len_in_batch, @@ -127,7 +127,7 @@ def _alibi_decode_att( att_control.tp_alibi, self.infer_state.req_manager.req_to_token_indexs, self.infer_state.b_req_idx, - self.infer_state.b_start_loc, + self.infer_state.b_kv_start_loc, self.infer_state.b_seq_len, self.infer_state.max_len_in_batch, self.infer_state.total_token_num, @@ -253,7 +253,7 @@ def _normal_decode_stage3_att( att_m_tensor, Req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, B_req_idx=self.infer_state.b_req_idx, - B_Start_Loc=self.infer_state.b_start_loc, + B_Start_Loc=self.infer_state.b_kv_start_loc, B_Seqlen=self.infer_state.b_seq_len, max_len_in_batch=self.infer_state.max_len_in_batch, ) @@ -269,7 +269,7 @@ def _normal_decode_stage3_att( o_tensor.view(calcu_shape1), req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, b_req_idx=self.infer_state.b_req_idx, - b_start_loc=self.infer_state.b_start_loc, + b_start_loc=self.infer_state.b_kv_start_loc, b_seq_len=self.infer_state.b_seq_len, ) return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py index 74b595802e..14a194a438 100644 --- a/lightllm/common/basemodel/attention/triton/int4kv.py +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -103,7 +103,7 @@ def _groupsize_quant_prefill_att( k=k_dequant, v=v_dequant, o=o_tensor, - b_start_loc=self.infer_state.b_start_loc, + b_start_loc=self.infer_state.b_q_start_loc, b_kv_start_loc=self.b_kv_start_loc, b_seq_len=self.infer_state.b_seq_len, max_q_input_len=self.infer_state.max_q_seq_len, diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index 492f51cccb..1471fbd699 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -104,7 +104,7 @@ def _groupsize_quant_prefill_att( k=k_dequant, v=v_dequant, o=o_tensor, - b_start_loc=self.infer_state.b_start_loc, + b_start_loc=self.infer_state.b_q_start_loc, b_kv_start_loc=self.b_kv_start_loc, b_seq_len=self.infer_state.b_seq_len, max_q_input_len=self.infer_state.max_q_seq_len, diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index b628d7239e..6fe171120e 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -55,7 +55,7 @@ def _mla_prefill_att( k_rope, v, o_tensor, - self.infer_state.b_start_loc, + self.infer_state.b_q_start_loc, self.infer_state.b1_cu_kv_seq_len, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index cd826414fd..e046197077 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -32,7 +32,6 @@ def __init__(self): self.batch_size: int = None self.total_token_num: int = None self.b_req_idx: torch.Tensor = None - self.b_start_loc: torch.Tensor = None self.b_ready_cache_len: torch.Tensor = None # only for prefill prompt cache used. self.b_shared_seq_len: torch.Tensor = None # only for diverse mode used in decode phase. @@ -77,6 +76,11 @@ def __init__(self): self.max_q_seq_len: int = None self.max_kv_seq_len: int = None + # prefill 用 + self.b_q_start_loc: torch.Tensor = None + # decode 用 + self.b_kv_start_loc: torch.Tensor = None + # 一些特殊模型,特殊模式使用的输入变量,本身这些变量不适合放在 # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 @@ -111,7 +115,7 @@ def init_some_extra_state(self, model): b_ready_cache_len=self.b_ready_cache_len, b_seq_len=self.b_seq_len, ) - self.b_start_loc = self.b1_cu_q_seq_len[0:-1] + self.b_q_start_loc = self.b1_cu_q_seq_len[0:-1] else: ( self.b_q_seq_len, @@ -122,7 +126,7 @@ def init_some_extra_state(self, model): ) = gen_decode_params(self.b_seq_len) # TODO: check the correctness self.max_kv_seq_len = self.max_len_in_batch - self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] + self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1] def init_att_state(self): if self.is_prefill: From 6a637e5f3b5cfe1e5f38c921c45aeefacc2b9f99 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 06:00:19 +0000 Subject: [PATCH 076/114] fix --- lightllm/models/qwen2_vl/infer_struct.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index c7b5ba041c..747be932d9 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -32,10 +32,6 @@ def init_some_extra_state(self, model): self.position_ids = self.position_ids.contiguous() self.position_cos = model._cos_cached[self.position_ids] self.position_sin = model._sin_cached[self.position_ids] - if get_env_start_args().enable_fa3: - self.max_seq_len = self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len - self.init_flash_attention_state_func(model) return def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: @@ -82,6 +78,6 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: position_ids=position_ids, b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, + b_start_loc=self.b_q_start_loc, ) return position_ids From e0f14b65b9bee88cceef3ada47e73974edc44c52 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 06:21:45 +0000 Subject: [PATCH 077/114] fix --- .../common/basemodel/triton_kernel/destindex_copy_kv.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py index be29f92667..060a92bf7c 100644 --- a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py @@ -16,12 +16,13 @@ def _fwd_kernel_destindex_copy_kv( stride_o_h, stride_o_d, head_num, + head_dim, BLOCK_DMODEL: tl.constexpr, BLOCK_HEAD: tl.constexpr, ): cur_index = tl.program_id(0) offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = (tl.arange(0, BLOCK_DMODEL)) % head_dim dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) @@ -54,7 +55,8 @@ def destindex_copy_kv(K, DestLoc, Out): Out.stride(1), Out.stride(2), head_num, - BLOCK_DMODEL=head_dim, + head_dim, + BLOCK_DMODEL=triton.next_power_of_2(head_dim), BLOCK_HEAD=BLOCK_HEAD, num_warps=num_warps, num_stages=1, From 53bad8869c0c1f30f1666549b6341a9c1ce77fa0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 06:24:09 +0000 Subject: [PATCH 078/114] fix --- .../layer_infer/transformer_layer_infer.py | 53 --- .../context_flashattention_nopad.py | 433 ------------------ .../phi3/triton_kernel/destindex_copy_kv.py | 192 -------- .../phi3/triton_kernel/flash_decoding.py | 37 -- .../triton_kernel/flash_decoding_stage1.py | 162 ------- .../triton_kernel/flash_decoding_stage2.py | 85 ---- 6 files changed, 962 deletions(-) delete mode 100644 lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py delete mode 100644 lightllm/models/phi3/triton_kernel/destindex_copy_kv.py delete mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding.py delete mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py delete mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index 806c59365b..0995b1414d 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,11 +1,5 @@ -import torch -from functools import partial from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -17,12 +11,6 @@ def __init__(self, layer_num, network_config, mode=[]): super().__init__(layer_num, network_config, mode) return - def _bind_attention(self): - self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self) - self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self) - return - def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): q = layer_weight.q_proj.mm(input_emb.view(-1, self.embed_dim_)) cache_kv = layer_weight.kv_proj.mm( @@ -35,44 +23,3 @@ def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Ph infer_state.position_sin, ) return q, cache_kv - - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) diff --git a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index ee04c3367b..0000000000 --- a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,433 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - head_dim: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, - other=0, - ).to(tl.int64) - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load( - K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0 - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) - acc = acc * acc_scale[:, None] - # update acc - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load( - V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0 - ) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs -): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_no_prompt_cache( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - head_dim, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (offs_d[:, None] < head_dim), - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (offs_d[None, :] < head_dim), - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_no_prompt_cache[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim, prompt_cache_len): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - xv = xv.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - mask_cache = torch.ones((seqlen, prompt_cache_len)).cuda().unsqueeze(0).unsqueeze(0).cuda() - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = torch.cat([mask_cache, mask], dim=-1) - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 10, 6, 500, 96 - dtype = torch.float16 - Z = 1 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.zeros((10, Z * N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - Z = 1 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prompt_cache_len[0] = 0 - prompt_cache_len = 0 - - b_seq_len[0] = 500 - b_req_idx[0] = 0 - req_to_token_indexs[0][: prompt_cache_len + N_CTX] = torch.tensor( - np.arange(prompt_cache_len + N_CTX), dtype=torch.int32 - ).cuda() - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att( - q[start:end], - k[start : end + prompt_cache_len], - v[start : end + prompt_cache_len], - 1, - b_seq_len[i], - H, - D_HEAD, - prompt_cache_len, - ) - start = end - torch_out.append(torch_o) - - torch_out = torch.cat(torch_out, dim=0) - - context_attention_fwd( - q, - k, - v, - o, - b_req_idx, - b_start_loc, - b_seq_len + prompt_cache_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, - ) - - # context_attention_fwd_no_prompt_cache( - # q, k, v, o, b_start_loc, b_seq_len, max_input_len - # ) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py b/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py deleted file mode 100644 index 4f31895ae0..0000000000 --- a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_kv( - K, - Dest_loc, - Out, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - - k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - - k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0) - tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def destindex_copy_kv(K, DestLoc, Out): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_kv[grid]( - K, - DestLoc, - Out, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num)) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding.py b/lightllm/models/phi3/triton_kernel/flash_decoding.py deleted file mode 100644 index e47e308864..0000000000 --- a/lightllm/models/phi3/triton_kernel/flash_decoding.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from .flash_decoding_stage1 import flash_decode_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) - - flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, - ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py b/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py deleted file mode 100644 index f6d8b5abee..0000000000 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage1( - Q, - K, - V, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eb, - stride_mid_o_eh, - stride_mid_o_es, - gqa_group_size, - head_dim, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - cur_kv_head = cur_head // gqa_group_size - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - - q = tl.load(Q + off_q, mask=offs_d < head_dim, other=0.0) - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - k = tl.load( - K + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) - v = tl.load( - V + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) - - cur_max_logic = tl.max(att_value, axis=0) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale - acc += tl.sum(exp_logic[:, None] * v, axis=0) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp, mask=offs_d < head_dim) - tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) - return - - -@torch.no_grad() -def flash_decode_stage1( - q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq -): - BLOCK_SEQ = block_seq - BLOCK_N = 16 - assert BLOCK_SEQ % BLOCK_N == 0 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - sm_scale = 1.0 / (Lk ** 0.5) - batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) - gqa_group_size = q.shape[1] // k.shape[1] - - _fwd_kernel_flash_decode_stage1[grid]( - q, - k, - v, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logsumexp.stride(0), - mid_out_logsumexp.stride(1), - mid_out_logsumexp.stride(2), - gqa_group_size, - head_dim, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - num_warps=1, - num_stages=2, - ) - return diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py b/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py deleted file mode 100644 index a06ee54545..0000000000 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage2( - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] - Out, # [batch, head, head_dim] - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eb, - stride_mid_o_eh, - stride_mid_o_es, - stride_obs, - stride_oh, - stride_od, - head_dim, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) - new_max_logic = tl.maximum(tlogic, max_logic) - - old_scale = tl.exp(max_logic - new_max_logic) - acc *= old_scale - exp_logic = tl.exp(tlogic - new_max_logic) - acc += exp_logic * tv - sum_exp = sum_exp * old_scale + exp_logic - max_logic = new_max_logic - - tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim) - return - - -@torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): - Lk = mid_out.shape[-1] - head_dim = Lk - batch, head_num = mid_out.shape[0], mid_out.shape[1] - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (batch, head_num) - - _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, - mid_out, - mid_out_logexpsum, - Out, - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logexpsum.stride(0), - mid_out_logexpsum.stride(1), - mid_out_logexpsum.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_dim, - BLOCK_SEQ=block_seq, - BLOCK_DMODEL=BLOCK_DMODEL, - num_warps=4, - num_stages=2, - ) - return From 6c3428b08689e681f2acaaf83a958719092e2169 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 06:32:07 +0000 Subject: [PATCH 079/114] fix --- .../transformer_layer_infer_cohere_template.py | 10 ---------- .../layer_infer/transformer_layer_infer.py | 18 +++++++++--------- .../layer_infer/transformer_layer_infer.py | 4 +++- 3 files changed, 12 insertions(+), 20 deletions(-) diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py index 27f71a17ec..379e891d1e 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py @@ -1,14 +1,9 @@ from functools import partial from typing import Tuple - import torch import torch.distributed as dist - from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time - from ...infer_struct import InferStateInfo -from ..transformer_layer_infer import TransformerLayerInfer from lightllm.distributed.communication_op import all_reduce @@ -30,11 +25,6 @@ def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Ten def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def _bind_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - self._att_norm = partial(TransformerLayerCohereInferTpl._q_norm, self) - self._q_norm = partial(TransformerLayerCohereInferTpl._k_norm, self) - self._k_norm = partial(TransformerLayerCohereInferTpl._att_norm, self) - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): raise Exception("need to impl") diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py index 0cdd281a37..b3dcba937a 100644 --- a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py @@ -30,6 +30,15 @@ def _bind_func(self): self._bind_norm() self._bind_attn() + def _bind_norm(self): + self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) + self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) + self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) + + def _bind_attn(self): + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_attention_kernel, self) + def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): return rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), @@ -52,15 +61,6 @@ def _q_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight def _k_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): return layernorm_forward(input, layer_weight.k_norm_weight_.weight, self.eps_) - def _bind_norm(self): - self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) - self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) - self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) - - def _bind_attn(self): - # no need to re-impl - LlamaTransformerLayerInfer._bind_attention(self) - def _get_o( self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight ) -> torch.Tensor: diff --git a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py index 018816fcc6..561ffc316f 100644 --- a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py @@ -1,5 +1,6 @@ from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from functools import partial class StarcoderTransformerLayerInfer(BloomTransformerLayerInfer): @@ -13,5 +14,6 @@ def __init__(self, layer_num, network_config, mode=[]): return def _bind_func(self): - LlamaTransformerLayerInfer._bind_attention(self) + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_attention_kernel, self) return From 9e0d4cb053d17fd80a42f556508c0223d85de3b8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:09:09 +0000 Subject: [PATCH 080/114] fix --- lightllm/models/cohere/model.py | 5 - lightllm/models/deepseek2/infer_struct.py | 15 -- .../layer_infer/transformer_layer_infer.py | 6 - lightllm/models/gemma3/model.py | 4 - .../layer_infer/transformer_layer_infer.py | 122 ++++------ lightllm/models/gpt_oss/model.py | 7 +- .../layer_weights/transformer_layer_weight.py | 4 - lightllm/models/internlm/model.py | 3 - .../pre_and_post_layer_weight.py | 2 - lightllm/models/mistral/model.py | 2 - .../models/mistral/triton_kernel/__init__.py | 0 .../context_flashattention_nopad.py | 228 ------------------ .../init_att_sliding_window_info.py | 45 ---- .../token_attention_nopad_att1.py | 128 ---------- .../token_attention_nopad_reduceV.py | 124 ---------- .../token_attention_softmax_and_reducev.py | 132 ---------- .../layer_infer/transformer_layer_infer.py | 1 - .../layer_infer/transformer_layer_infer.py | 3 - lightllm/models/qwen3/model.py | 2 - .../layer_infer/transformer_layer_infer.py | 8 - .../layer_infer/transformer_layer_infer.py | 1 - lightllm/models/qwen_vl/model.py | 2 - lightllm/models/stablelm/model.py | 3 - 23 files changed, 53 insertions(+), 794 deletions(-) delete mode 100644 lightllm/models/mistral/triton_kernel/__init__.py delete mode 100644 lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py delete mode 100644 lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py delete mode 100644 lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py delete mode 100644 lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py delete mode 100644 lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py index 5b317c1331..05ccaac3e3 100644 --- a/lightllm/models/cohere/model.py +++ b/lightllm/models/cohere/model.py @@ -1,10 +1,5 @@ import os import torch -from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, -) -from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.registry import ModelRegistry from lightllm.models.cohere.infer_struct import CohereInferStateInfo from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index 0c2ef30489..4dd79305c3 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -1,21 +1,6 @@ -import os -import torch -import numpy as np -import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo class Deepseek2InferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() - self.kv_starts = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - if not self.is_prefill: - self.kv_starts = self.b1_cu_kv_seq_len - - if self.is_prefill: - self.b1_kv_start_loc = self.b1_cu_kv_seq_len - self.max_value_in_b_seq_len = self.max_kv_seq_len - return diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index d4bd8c3fa6..6f87710917 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -1,12 +1,6 @@ import torch -import torch.functional as F import torch.distributed as dist import torch.nn as nn -import numpy as np -from typing import Tuple -from functools import partial -import triton - from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.distributed import all_reduce from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index dc4f03b7e1..9931c31713 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -1,7 +1,5 @@ import os -import re import json -import numpy as np import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer @@ -14,8 +12,6 @@ from lightllm.models.gemma3.layer_weights.pre_and_post_layer_weight import Gemma3PreAndPostLayerWeight from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.llava.layer_weights.pre_and_post_layer_weight import LlavaPreAndPostLayerWeight from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem from lightllm.server.core.objs import SamplingParams from lightllm.common.build_utils import repair_config diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index dbcf379468..d02e0ba8a8 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -1,13 +1,7 @@ import torch -from torch import nn -from torch.nn import functional as F -import numpy as np -from functools import partial -from typing import Optional - from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, LlamaInferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -23,22 +17,17 @@ def __init__(self, layer_num, network_config, mode=[]): self.sliding_window = network_config["sliding_window"] self.head_dim_ = network_config["head_dim"] - def _bind_attention(self): - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._context_attention_kernel = self._context_sliding_attention_flashattention - self._token_attention_kernel = self._token_sliding_attention_flashattention - def _bind_norm(self): self._att_norm = self._att_norm self._ffn_norm = self._ffn_norm return - def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _att_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_) return out - def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _ffn_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_) return out @@ -65,78 +54,61 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - ) return hidden_states.view(num_tokens, hidden_dim) - def _context_sliding_attention_flashattention( - self, q, kv, infer_state, layer_weight: GptOssTransformerLayerWeight, out=None + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: LlamaInferStateInfo, + layer_weight: GptOssTransformerLayerWeight, + out=None, ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, - ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor - def _token_sliding_attention_flashattention( - self, q, infer_state, layer_weight: GptOssTransformerLayerWeight, out=None + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 34a017b316..dc5f2abdfe 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -19,4 +19,9 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) - assert get_env_start_args().enable_fa3, "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_prefill_att_backend[0] == "fa3" + ), "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_decode_att_backend[0] == "fa3" + ), "For now GPT-OSS type model only support flashattention-3" diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index a2fc91dc46..858c192f4c 100755 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -1,7 +1,3 @@ -import torch -import math -import numpy as np - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/internlm/model.py b/lightllm/models/internlm/model.py index 78ac7117e1..50adbb3f9f 100644 --- a/lightllm/models/internlm/model.py +++ b/lightllm/models/internlm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel diff --git a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py index b4c070a1e6..e0e2e11845 100644 --- a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index d2bfeaa952..f09525c59f 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -1,5 +1,3 @@ -import os -import json import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel import TpPartBaseModel diff --git a/lightllm/models/mistral/triton_kernel/__init__.py b/lightllm/models/mistral/triton_kernel/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py b/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index abcaf02b51..0000000000 --- a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - sliding_window, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # [SYM] mask outside of windows - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk = tl.where((start_n + offs_n[None, :]) > (offs_m[:, None] - sliding_window), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - -@torch.no_grad() -def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, sliding_window): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - sliding_window=sliding_window, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - - Z, H, N_CTX, D_HEAD = 4, 6, 1024, 128 - dtype = torch.float16 - Z = 3 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - - max_input_len = N_CTX - Z = 4 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - - b_seq_len[0] = 512 - b_seq_len[1] = 1024 - b_seq_len[2] = 512 - b_seq_len[3] = 1024 - - for i in range(1, Z): - b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att(q[start:end], k[start:end], v[start:end], 1, b_seq_len[i], H, D_HEAD) - start = end - torch_out.append(torch_o) - torch_out = torch.cat(torch_out, dim=0) - context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, 10) - print(o.shape, torch_out.shape) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py deleted file mode 100644 index a60fe970b3..0000000000 --- a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_init_att_window_info( - b_seq_len, - b_att_seq_len, - batch_size, - sliding_window, - BLOCK_SIZE: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_start = cur_index * BLOCK_SIZE - offsets = cur_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < batch_size - - cur_seq_len = tl.load(b_seq_len + offsets, mask=mask) - b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window) - - tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask) - return - - -@torch.no_grad() -def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window): - # shape constraints - assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0] - - BLOCK_SIZE = 32 - num_warps = 1 - grid = (triton.cdiv(batch_size, BLOCK_SIZE),) - - _fwd_kernel_init_att_window_info[grid]( - b_seq_len, - b_att_seq_len, - batch_size=batch_size, - sliding_window=sliding_window, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py deleted file mode 100644 index 9a8261132a..0000000000 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py +++ /dev/null @@ -1,128 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math - - -@triton.jit -def _fwd_kernel_token_att1( - Q, - K, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - Att_Out, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - att_stride_h, - att_stride_bs, - kv_group_num, - sliding_window, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) - - # use new start index of k value - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) - cur_batch_end_index = cur_batch_seq_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D] - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32] - - # use new value to decide block mask - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark - offs_n_new = cur_batch_start_index + offs_n # the latest window of token - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = ( - k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - ) # [32, D], find token index - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32] - att_value = att_value.to(tl.float32) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - -@torch.no_grad() -def token_att_fwd( - q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk ** 0.5) - - batch, head_num = B_req_idx.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK)) - kv_group_num = q.shape[1] // k.shape[1] - - if kv_group_num == 1: - num_warps = 4 - else: - num_warps = 2 - - _fwd_kernel_token_att1[grid]( - q, - k, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - att_out, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - sliding_window=sliding_window, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py deleted file mode 100644 index acf4923f82..0000000000 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_token_att2( - Prob, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - sliding_window, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - # cur_batch_end_index = cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # att length - - v_loc_off = ( - cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - ) # the latest window of value [64] - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs # [64] - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_att_seq_len, other=0.0) # [64] - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n + cur_batch_start_index) < cur_batch_seq_len, - other=0.0, - ) # [64] - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, - other=0.0, - ) # [1, D] + [64, 1] = [64, D] - acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = [64, D] -> [D] - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2( - prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): - BLOCK = 128 - # BLOCK = 64 # for triton 2.0.0dev - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2[grid]( - prob, - v, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - siliding_window=sliding_window, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - out = torch.matmul(P, V) - - return out diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py deleted file mode 100644 index bf9928f987..0000000000 --- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel( - Logics, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_token_b, - stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 - kv_group_num, - sliding_window, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index - cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s, - mask=(start_n + offs_n) < cur_att_seq_len, - other=other_kv_index, - ) # [64] - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=(start_n + offs_n) < cur_att_seq_len, - other=float("-inf"), - ) # [64] - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D] - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D] - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_softmax_reducev_fwd( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - sliding_window, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - _fwd_kernel[grid]( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_tokens.stride(0), - req_to_tokens.stride(1), - 0, - kv_group_num, - sliding_window, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index a60375688b..c5842b8289 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -3,7 +3,6 @@ import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 7a4b2ca816..2576d7affd 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -1,7 +1,4 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd diff --git a/lightllm/models/qwen3/model.py b/lightllm/models/qwen3/model.py index 21e71e0e02..e48b36e0f7 100644 --- a/lightllm/models/qwen3/model.py +++ b/lightllm/models/qwen3/model.py @@ -1,5 +1,3 @@ -import torch -from typing import final from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 17ce4b7693..9ce475e974 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -1,20 +1,12 @@ import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from functools import partial from typing import Tuple from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused -from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index b155f8b907..facad2e56b 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -7,7 +7,6 @@ from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index edebccf17f..d942d68497 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -1,5 +1,3 @@ -import json -import numpy as np import unicodedata from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.server.core.objs import SamplingParams diff --git a/lightllm/models/stablelm/model.py b/lightllm/models/stablelm/model.py index 2ed710fd4c..a3d295358f 100644 --- a/lightllm/models/stablelm/model.py +++ b/lightllm/models/stablelm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.stablelm.layer_infer.transformer_layer_infer import StablelmTransformerLayerInfer from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer From ff1091246525f865e4398e6d0c9109285a02e541 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:23:23 +0000 Subject: [PATCH 081/114] fix --- docs/CN/source/tutorial/api_server_args_zh.rst | 12 ++---------- docs/EN/source/tutorial/api_server_args_zh.rst | 12 ------------ lightllm/models/deepseek2/model.py | 3 --- lightllm/models/llama/model.py | 3 --- lightllm/server/api_cli.py | 18 ------------------ lightllm/server/api_start.py | 15 --------------- lightllm/server/core/objs/start_args_type.py | 5 ----- lightllm/utils/envs_utils.py | 8 +++++++- .../triton_kernel/test_gen_decode_params.py | 12 ------------ .../server/core/objs/test_shm_req_manager.py | 2 -- 10 files changed, 9 insertions(+), 81 deletions(-) diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index ce7a79ab97..d478f28918 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -327,17 +327,9 @@ attention类型选择参数 推理后端将为解码使用微批次重叠模式 -.. option:: --enable_flashinfer_prefill +.. option:: --llm_kv_type - 推理后端将为预填充使用 flashinfer 的注意力 kernel - -.. option:: --enable_flashinfer_decode - - 推理后端将为解码使用 flashinfer 的注意力 kernel - -.. option:: --enable_fa3 - - 推理后端将为预填充和解码使用 fa3 注意力 kernel + 推理后端使用什么类型的数据存储kv cache, 可选值为 "None", "int8kv", "int4kv", "fp8kv" .. option:: --disable_cudagraph diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 1644bbab5f..aae20ecbf5 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -325,18 +325,6 @@ Performance Optimization Parameters .. option:: --enable_decode_microbatch_overlap The inference backend will use microbatch overlap mode for decoding - -.. option:: --enable_flashinfer_prefill - - The inference backend will use flashinfer's attention kernel for prefill - -.. option:: --enable_flashinfer_decode - - The inference backend will use flashinfer's attention kernel for decoding - -.. option:: --enable_fa3 - - The inference backend will use fa3 attention kernel for prefill and decoding .. option:: --disable_cudagraph diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f9fe06a058..f0739a8a81 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -26,9 +26,6 @@ class Deepseek2TpPartModel(LlamaTpPartModel): infer_state_class = Deepseek2InferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 6b616ef29b..c104ebccc9 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -33,9 +33,6 @@ class LlamaTpPartModel(TpPartBaseModel): infer_state_class = LlamaInferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 10aef09233..1eff756186 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -367,24 +367,6 @@ def make_argument_parser() -> argparse.ArgumentParser: this params will be effective. """, ) - - parser.add_argument( - "--enable_flashinfer_prefill", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for prefill, - only deepseekv3 model supported now.""", - ) - parser.add_argument( - "--enable_flashinfer_decode", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for decode, - only deepseekv3 model supported now.""", - ) - parser.add_argument( - "--enable_fa3", - action="store_true", - help="""inference backend will use the fa3 attention kernel for prefill and decode""", - ) parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" ) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 4ead3cbbf7..3ae3789f47 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -122,21 +122,6 @@ def normal_or_p_d_start(args): if args.return_all_prompt_logprobs: assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache" assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill" - if "offline_calibration_fp8kv" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - if "export_fp8kv_calibration" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" # 部分模式还不能支持与高级动态调度算法协同,to do. if args.diverse_mode: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 4ca39de9a9..25e6998fdc 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -116,8 +116,6 @@ class StartArgs: quant_cfg: Optional[str] = field(default=None) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) - enable_flashinfer_prefill: bool = field(default=False) - enable_flashinfer_decode: bool = field(default=False) llm_prefill_att_backend: List[str] = field( default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} ) @@ -161,6 +159,3 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) - - # kernel setting - enable_fa3: bool = field(default=False) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b5822a342c..0b54ef5dce 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -26,10 +26,16 @@ def get_unique_server_name(): def set_cuda_arch(args): if not torch.cuda.is_available(): return - if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + + from lightllm.server.core.objs.start_args_type import StartArgs + + args: StartArgs = args + + if "flashinfer" in args.llm_prefill_att_backend or "flashinfer" in args.llm_decode_att_backend: capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + return def set_env_start_args(args): diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py index 5c3ca89c65..41bc217b94 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py @@ -1,21 +1,9 @@ import torch import pytest -import easydict from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params -from lightllm.utils.envs_utils import set_env_start_args def test_gen_decode_params_basic(): - set_env_start_args( - easydict.EasyDict( - { - "mtp_step": 0, - "enable_flashinfer_prefill": False, - "enable_flashinfer_decode": False, - } - ) - ) - b_seq_len = torch.ones((9,), dtype=torch.int64, device="cuda") * 8192 ( b_q_seq_len, diff --git a/unit_tests/server/core/objs/test_shm_req_manager.py b/unit_tests/server/core/objs/test_shm_req_manager.py index 1d1ae2ef1a..dea40a4859 100644 --- a/unit_tests/server/core/objs/test_shm_req_manager.py +++ b/unit_tests/server/core/objs/test_shm_req_manager.py @@ -14,8 +14,6 @@ def setup_env(): running_max_req_size=10, disable_chunked_prefill=True, token_healing_mode=False, - enable_flashinfer_prefill=False, - enable_flashinfer_decode=False, ) ) # clear the lru_cache if used From 35e8a8e5f857165b4bacaf11b1cc8a47812c8071 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:42:09 +0000 Subject: [PATCH 082/114] fix all --- .../source/tutorial/deepseek_deployment.rst | 36 ++++++++++++------- .../tutorial/multi_level_cache_deployment.rst | 8 +++-- docs/CN/source/tutorial/reasoning_parser.rst | 3 +- .../source/tutorial/deepseek_deployment.rst | 33 ++++++++++------- .../tutorial/multi_level_cache_deployment.rst | 8 +++-- docs/EN/source/tutorial/reasoning_parser.rst | 3 +- .../offline_fp8_quant_mem_manager.py | 21 +++++++---- test/acc/test_deepseekr1.sh | 2 +- test/acc/test_deepseekr1_mtp.sh | 2 +- test/acc/test_deepseekr1_mtp_ep.sh | 2 +- test/acc/test_qwen2.sh | 2 +- test/acc/test_qwen3.sh | 2 +- test/start_scripts/README.md | 1 - test/start_scripts/draft.sh | 2 +- test/start_scripts/multi_node_ep_node0.sh | 2 +- test/start_scripts/multi_node_ep_node1.sh | 2 +- test/start_scripts/multi_node_tp_node0.sh | 2 +- test/start_scripts/multi_node_tp_node1.sh | 2 +- .../multi_pd_master/pd_decode.sh | 2 +- .../multi_pd_master/pd_prefill.sh | 2 +- test/start_scripts/single_node_ep.sh | 2 +- test/start_scripts/single_node_tp.sh | 2 +- .../single_pd_master/pd_decode.sh | 2 +- .../single_pd_master/pd_nixl_decode.sh | 2 +- .../single_pd_master/pd_nixl_prefill.sh | 2 +- .../single_pd_master/pd_prefill.sh | 2 +- 26 files changed, 91 insertions(+), 58 deletions(-) diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 071d9405ab..5d57b137c6 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,14 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `LOADWORKER=18`: 模型加载线程数,提高加载速度 - `--tp 8`: 张量并行度,使用8个GPU -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 - `--port 8088`: 服务端口 1.2 单机 DP + EP 模式 (Data Parallel + Expert Parallel) @@ -55,13 +57,15 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `MOE_MODE=EP`: 设置专家并行模式 - `--tp 8`: 张量并行度 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值 -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 **可选优化参数:** - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠 @@ -85,7 +89,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +106,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +135,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +153,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +203,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -219,7 +228,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -287,7 +297,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -306,7 +317,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # 如果需要启用微批次重叠,可以取消注释以下行 diff --git a/docs/CN/source/tutorial/multi_level_cache_deployment.rst b/docs/CN/source/tutorial/multi_level_cache_deployment.rst index 0446b07804..223b92dca3 100644 --- a/docs/CN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/CN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ LightLLM 的多级缓存系统采用分层设计: --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ LightLLM 的多级缓存系统采用分层设计: - ``--model_dir``: 模型文件路径,支持本地路径或 HuggingFace 模型名称 - ``--tp 8``: 张量并行度,使用 8 个 GPU 进行模型推理 - ``--graph_max_batch_size 500``: CUDA Graph 最大批次大小,影响吞吐量和显存占用 -- ``--enable_fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 +- ``--llm_prefill_att_backend fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 - ``--mem_fraction 0.88``: GPU 显存使用比例,建议设置为 0.88及以下 CPU 缓存参数 @@ -130,7 +131,8 @@ CPU 缓存参数 --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/CN/source/tutorial/reasoning_parser.rst b/docs/CN/source/tutorial/reasoning_parser.rst index 547eb05d16..a9a0d09fe4 100644 --- a/docs/CN/source/tutorial/reasoning_parser.rst +++ b/docs/CN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 6098411be0..accdbc462b 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,13 @@ Suitable for deploying DeepSeek-R1 model on a single H200 node. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `LOADWORKER=18`: Model loading thread count, improves loading speed - `--tp 8`: Tensor parallelism, using 8 GPUs -- `--enable_fa3`: Enable Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: Enable Flash Attention 3.0 - `--port 8088`: Service port 1.2 Single node DP + EP Mode (Data Parallel + Expert Parallel) @@ -55,13 +56,13 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3. --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `MOE_MODE=EP`: Set expert parallelism mode - `--tp 8`: Tensor parallelism - `--dp 8`: Data parallelism, usually set to the same value as tp -- `--enable_fa3`: Enable Flash Attention 3.0 **Optional Optimization Parameters:** - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap @@ -85,7 +86,8 @@ Suitable for deployment across multiple H200/H100 nodes. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +103,8 @@ Suitable for deployment across multiple H200/H100 nodes. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +132,8 @@ Suitable for deploying MoE models across multiple nodes. --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +150,8 @@ Suitable for deploying MoE models across multiple nodes. --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +200,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip @@ -216,7 +222,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -284,7 +291,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -303,7 +311,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/docs/EN/source/tutorial/multi_level_cache_deployment.rst b/docs/EN/source/tutorial/multi_level_cache_deployment.rst index bb8d943b87..6c99c351f0 100644 --- a/docs/EN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/EN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ Suitable for most scenarios, significantly increasing cache capacity while maint --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ Basic Parameters - ``--model_dir``: Model file path, supports local path or HuggingFace model name - ``--tp 8``: Tensor parallelism degree, using 8 GPUs for model inference - ``--graph_max_batch_size 500``: CUDA Graph maximum batch size, affects throughput and memory usage -- ``--enable_fa3``: Enable Flash Attention 3.0 to improve attention computation speed. You can also switch to flashinfer backend for better performance +- ``--llm_prefill_att_backend fa3``: Enable Flash Attention 3.0 to improve attention computation speed. You can also switch to flashinfer backend for better performance - ``--mem_fraction 0.88``: GPU memory usage ratio, recommended to set to 0.88 or below CPU Cache Parameters @@ -130,7 +131,8 @@ Suitable for ultra-long text or extremely high-concurrency scenarios, providing --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/EN/source/tutorial/reasoning_parser.rst b/docs/EN/source/tutorial/reasoning_parser.rst index e76e093d63..56e61e6cd6 100644 --- a/docs/EN/source/tutorial/reasoning_parser.rst +++ b/docs/EN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py index 5cc0b12d03..56a79a3b57 100755 --- a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py @@ -31,8 +31,10 @@ def __init__( self.scales_list = None self.abs_max = None + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + if is_export_mode: - scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2] + scales_shape = [layer_num, 2 * head_num] if enable_fa3 else [layer_num, 2] self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda") elif get_env_start_args().kv_quant_calibration_config_path is not None: logger.info( @@ -43,7 +45,7 @@ def __init__( self.scales_list = cfg["scales"] self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) - if not get_env_start_args().enable_fa3: + if not enable_fa3: self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1) elif cfg["num_head"] > self.total_head_num: factor = cfg["num_head"] // self.total_head_num @@ -51,7 +53,7 @@ def __init__( elif cfg["num_head"] < self.total_head_num: factor = self.total_head_num // cfg["num_head"] self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() - if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: + if enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: half_head = self.total_head_num // 2 start_head = dist.get_rank() * head_num end_head = start_head + head_num @@ -65,6 +67,8 @@ def __init__( logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales") def _load_and_check_config(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + if os.path.exists(get_env_start_args().kv_quant_calibration_config_path): with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f: cfg = json.load(f) @@ -86,7 +90,7 @@ def _load_and_check_config(self): raise ValueError( f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}" ) - if get_env_start_args().enable_fa3: + if enable_fa3: if cfg["quant_type"] != "per_head": raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend") else: @@ -100,6 +104,7 @@ def _load_and_check_config(self): ) def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend inference_counts = get_kv_quant_calibration_inference_count() warmup_counts = get_kv_quant_calibration_warmup_count() if not get_model_init_status() or self.count >= warmup_counts + inference_counts: @@ -109,7 +114,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): logger.info("kv cache calibration mode will collect kv cache data for quantization calibration") if self.abs_max is not None and self.count >= warmup_counts: - if get_env_start_args().enable_fa3: + if enable_fa3: kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32) else: k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32) @@ -119,7 +124,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1: final_abs_max = self.abs_max if dist.is_initialized() and dist.get_world_size() > 1: - if get_env_start_args().enable_fa3: + if enable_fa3: k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1) k_max = k_max.contiguous() v_max = v_max.contiguous() @@ -144,11 +149,13 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): self.count += 1 def _export_calibration_data(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + model_arch = get_model_architectures(get_env_start_args().model_dir) cfg = { "version": "1.0", "architectures": model_arch, - "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor", + "quant_type": "per_head" if enable_fa3 else "per_tensor", "qmin": self.qmin, "qmax": self.qmax, "num_layers": self.layer_num, diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh index e167303a35..5fcfc0c08b 100644 --- a/test/acc/test_deepseekr1.sh +++ b/test/acc/test_deepseekr1.sh @@ -1,4 +1,4 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --enable_fa3 +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh index 046314a728..7eaffd4993 100644 --- a/test/acc/test_deepseekr1_mtp.sh +++ b/test/acc/test_deepseekr1_mtp.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 2ea5f74387..0467f76e6a 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh index 265d679e8a..bb5603b5be 100644 --- a/test/acc/test_qwen2.sh +++ b/test/acc/test_qwen2.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-Math-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh index c0da5ec96e..36a3c96804 100644 --- a/test/acc/test_qwen3.sh +++ b/test/acc/test_qwen3.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index f5dae19b92..e00af27139 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -108,7 +108,6 @@ sh multi_pd_master/pd_decode.sh - `--model_dir`: Model file path - `--tp`: Tensor parallelism degree - `--dp`: Data parallelism degree -- `--enable_fa3`: Enable Flash Attention 3.0 - `--nnodes`: Total number of nodes - `--node_rank`: Current node rank - `--nccl_host`: NCCL communication host address diff --git a/test/start_scripts/draft.sh b/test/start_scripts/draft.sh index 866f5f2fa5..235f4427a0 100644 --- a/test/start_scripts/draft.sh +++ b/test/start_scripts/draft.sh @@ -16,7 +16,7 @@ HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh index 3a139968a6..68f80b39d5 100644 --- a/test/start_scripts/multi_node_ep_node0.sh +++ b/test/start_scripts/multi_node_ep_node0.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh index b24a598688..10aee85285 100644 --- a/test/start_scripts/multi_node_ep_node1.sh +++ b/test/start_scripts/multi_node_ep_node1.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node0.sh b/test/start_scripts/multi_node_tp_node0.sh index b86bdeb358..d750da93ca 100644 --- a/test/start_scripts/multi_node_tp_node0.sh +++ b/test/start_scripts/multi_node_tp_node0.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node1.sh b/test/start_scripts/multi_node_tp_node1.sh index 378977ab2e..cb495496e8 100644 --- a/test/start_scripts/multi_node_tp_node1.sh +++ b/test/start_scripts/multi_node_tp_node1.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh index 4cefef6fb2..2b7bb80d76 100644 --- a/test/start_scripts/multi_pd_master/pd_decode.sh +++ b/test/start_scripts/multi_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --nccl_port 12322 \ --tp 8 \ --dp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh index b845da435d..eaa343ef62 100644 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --tp 8 \ --dp 8 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh index cad172d515..7406d94628 100644 --- a/test/start_scripts/single_node_ep.sh +++ b/test/start_scripts/single_node_ep.sh @@ -3,7 +3,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_node_tp.sh b/test/start_scripts/single_node_tp.sh index 1fb461bb11..ee10b6c101 100644 --- a/test/start_scripts/single_node_tp.sh +++ b/test/start_scripts/single_node_tp.sh @@ -2,7 +2,7 @@ LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index ae16b96ad4..36804dd11e 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 1b43c11cc4..5fb34a973e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -18,7 +18,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index 303de29758..5a37df0b1d 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -19,7 +19,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index f6e2e4b685..b94a1f8ccd 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 From 69a100b1089a54f297fc92efad6fb673ecb42611 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:47:34 +0000 Subject: [PATCH 083/114] remove mode --- .../CN/source/tutorial/api_server_args_zh.rst | 16 ------------- .../EN/source/tutorial/api_server_args_zh.rst | 17 -------------- lightllm/common/basemodel/basemodel.py | 12 ++++------ lightllm/server/api_cli.py | 23 ------------------- lightllm/server/core/objs/start_args_type.py | 1 - test/start_scripts/draft.sh | 2 +- .../single_node_tp_cpu_cache_enable.sh | 2 +- 7 files changed, 6 insertions(+), 67 deletions(-) diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index d478f28918..5976fcb322 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -183,22 +183,6 @@ PD 分离模式参数 设置为 True 时,--nccl_host 必须等于 config_server_host,--nccl_port 对于 config_server 必须是唯一的, 不要为不同的推理节点使用相同的 nccl_port,这将是严重错误 -attention类型选择参数 ---------------------- - -.. option:: --mode - - 模型推理模式,可以指定多个值: - - * ``triton_int8kv``: 使用 int8 存储 kv cache,可增加 token 容量,使用 triton kernel - * ``ppl_int8kv``: 使用 int8 存储 kv cache,使用 ppl 快速 kernel - * ``ppl_fp16``: 使用 ppl 快速 fp16 解码注意力 kernel - * ``triton_flashdecoding``: 用于长上下文的 flashdecoding 模式,当前支持 llama llama2 qwen - * ``triton_gqa_attention``: 使用 GQA 的模型的快速 kernel - * ``triton_gqa_flashdecoding``: 使用 GQA 的模型的快速 flashdecoding kernel - * ``triton_fp8kv``: 使用 float8 存储 kv cache,目前仅用于 deepseek2 - - 需要阅读源代码以确认所有模型支持的具体模式 调度参数 -------- diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index aae20ecbf5..0767ae7e3b 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -183,23 +183,6 @@ Different Parallel Mode Setting Parameters When set to True, --nccl_host must equal config_server_host, --nccl_port must be unique for config_server, do not use the same nccl_port for different inference nodes, this will be a serious error -Attention Type Selection Parameters ------------------------------------- - -.. option:: --mode - - Model inference mode, can specify multiple values: - - * ``triton_int8kv``: Use int8 to store kv cache, can increase token capacity, uses triton kernel - * ``ppl_int8kv``: Use int8 to store kv cache, uses ppl fast kernel - * ``ppl_fp16``: Use ppl fast fp16 decode attention kernel - * ``triton_flashdecoding``: Flashdecoding mode for long context, currently supports llama llama2 qwen - * ``triton_gqa_attention``: Fast kernel for models using GQA - * ``triton_gqa_flashdecoding``: Fast flashdecoding kernel for models using GQA - * ``triton_fp8kv``: Use float8 to store kv cache, currently only used for deepseek2 - - Need to read source code to confirm specific modes supported by all models - Scheduling Parameters --------------------- diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 6513e25db9..9c30b1ff44 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -60,7 +60,6 @@ def __init__(self, kvargs): self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") - self.mode = kvargs.get("mode", []) self.weight_dict = kvargs.get("weight_dict", None) self.finetune_config = kvargs.get("finetune_config", None) self.max_req_num = kvargs.get("max_req_num", 1000) @@ -170,15 +169,12 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) @@ -228,10 +224,10 @@ def _init_req_manager(self): return def _init_infer_layer(self, start_layer_index=0): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1eff756186..422a656a87 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -219,29 +219,6 @@ def make_argument_parser() -> argparse.ArgumentParser: the --nccl_host must equal to the config_server_host, and the --nccl_port must be unique for a config_server, dont use same nccl_port for different inference node, it will be critical error""", ) - - parser.add_argument( - "--mode", - type=str, - default=[], - nargs="+", - help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_int8kv_flashdecoding | ppl_int8kv_flashdecoding_diverse - | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration - triton_flashdecoding mode is for long context, current support llama llama2 qwen; - triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; - triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; - triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2; - offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend, - currently only for llama and qwen model; - export_fp8kv_calibration record and export kv cache quant calibration results to a json file. - It can be used for llama and qwen model. - Calibration need to disable cudagraph and use fa3 or flashinfer backend. - ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; - ppl_fp16 mode use ppl fast fp16 decode attention kernel; - you need to read source code to make sure the supported detail mode for all models""", - ) parser.add_argument( "--trust_remote_code", action="store_true", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 25e6998fdc..7b440718f5 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -63,7 +63,6 @@ class StartArgs: nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) diff --git a/test/start_scripts/draft.sh b/test/start_scripts/draft.sh index 235f4427a0..04f573cd6d 100644 --- a/test/start_scripts/draft.sh +++ b/test/start_scripts/draft.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 diff --git a/test/start_scripts/single_node_tp_cpu_cache_enable.sh b/test/start_scripts/single_node_tp_cpu_cache_enable.sh index 3caabb59bd..47da83dbe9 100644 --- a/test/start_scripts/single_node_tp_cpu_cache_enable.sh +++ b/test/start_scripts/single_node_tp_cpu_cache_enable.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 From aaa3531d3fe6949fe9df8d65a106aa9ed0ad61a9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:51:52 +0000 Subject: [PATCH 084/114] fix --- lightllm/server/router/manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 89c46d9ed9..ac5c1abee3 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -58,7 +58,6 @@ def __init__(self, args: StartArgs): # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 self.is_safe_schedule = args.router_token_ratio == 0.0 self.load_way = args.load_way - self.mode = args.mode self.max_total_token_num = args.max_total_token_num self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 @@ -155,7 +154,6 @@ async def wait_to_model_ready(self): "weight_dir": self.model_weightdir, "load_way": self.load_way, "max_total_token_num": self.max_total_token_num, - "mode": self.mode, "max_req_num": self.args.running_max_req_size + 8, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_host": self.args.nccl_host, From d7faae09e9a4f7f868319c8599390aae60505c7d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 07:53:34 +0000 Subject: [PATCH 085/114] fix --- .../server/router/model_infer/mode_backend/base_backend.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 92653bc0cd..805c9b8e50 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -88,7 +88,6 @@ def init_model(self, kvargs): # dp_size_in_node 计算兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, self.dp_size // self.nnodes) self.load_way = kvargs["load_way"] - self.mode = kvargs["mode"] self.disable_chunked_prefill = self.args.disable_chunked_prefill self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs @@ -148,7 +147,6 @@ def init_model(self, kvargs): "weight_dir": self.weight_dir, "max_total_token_num": max_total_token_num, "load_way": self.load_way, - "mode": self.mode, "max_req_num": kvargs.get("max_req_num", 1000), "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": kvargs.get("is_token_healing", False), @@ -302,7 +300,6 @@ def init_mtp_draft_model(self, main_kvargs: dict): "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, "load_way": main_kvargs["load_way"], - "mode": main_kvargs["mode"], "max_req_num": main_kvargs.get("max_req_num", 1000), "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": False, From 9b64be61633a028dc326b540b6d5b0d9411996b8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 08:00:01 +0000 Subject: [PATCH 086/114] fix --- .../basemodel/layer_infer/post_layer_infer.py | 3 +- .../basemodel/layer_infer/pre_layer_infer.py | 3 +- .../template/post_layer_infer_template.py | 4 +- .../template/pre_layer_infer_template.py | 4 +- ...transformer_layer_infer_cohere_template.py | 126 ----------- .../transformer_layer_infer_template.py | 4 +- .../layer_infer/transformer_layer_infer.py | 3 +- .../pre_and_post_layer_weight.py | 3 +- .../layer_weights/transformer_layer_weight.py | 3 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/cohere/__init__.py | 0 lightllm/models/cohere/infer_struct.py | 8 - .../models/cohere/layer_infer/__init__.py | 0 .../cohere/layer_infer/post_layer_infer.py | 71 ------- .../layer_infer/transformer_layer_infer.py | 84 -------- .../models/cohere/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 25 --- .../layer_weights/transformer_layer_weight.py | 25 --- lightllm/models/cohere/model.py | 64 ------ .../models/cohere/triton_kernels/__init__.py | 0 .../models/cohere/triton_kernels/layernorm.py | 131 ------------ .../cohere/triton_kernels/rotary_emb.py | 199 ------------------ 23 files changed, 15 insertions(+), 753 deletions(-) delete mode 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py delete mode 100644 lightllm/models/cohere/__init__.py delete mode 100644 lightllm/models/cohere/infer_struct.py delete mode 100644 lightllm/models/cohere/layer_infer/__init__.py delete mode 100644 lightllm/models/cohere/layer_infer/post_layer_infer.py delete mode 100644 lightllm/models/cohere/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/cohere/layer_weights/__init__.py delete mode 100644 lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/cohere/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/cohere/model.py delete mode 100644 lightllm/models/cohere/triton_kernels/__init__.py delete mode 100644 lightllm/models/cohere/triton_kernels/layernorm.py delete mode 100644 lightllm/models/cohere/triton_kernels/rotary_emb.py diff --git a/lightllm/common/basemodel/layer_infer/post_layer_infer.py b/lightllm/common/basemodel/layer_infer/post_layer_infer.py index d254eb510a..c7bae26ead 100644 --- a/lightllm/common/basemodel/layer_infer/post_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/post_layer_infer.py @@ -4,8 +4,7 @@ class PostLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py index 3626346f20..e83fe89490 100644 --- a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py @@ -4,8 +4,7 @@ class PreLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py index fa7e96a694..1b7813fca3 100644 --- a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py @@ -6,8 +6,8 @@ class PostLayerInferTpl(PostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 self.vocab_size_ = network_config["vocab_size"] self.embed_dim_ = network_config["n_embed"] diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index e7a0840794..04f8cda16b 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -5,8 +5,8 @@ class PreLayerInferTpl(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 return diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py deleted file mode 100755 index 379e891d1e..0000000000 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py +++ /dev/null @@ -1,126 +0,0 @@ -from functools import partial -from typing import Tuple -import torch -import torch.distributed as dist -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl -from ...infer_struct import InferStateInfo -from lightllm.distributed.communication_op import all_reduce - - -class TransformerLayerCohereInferTpl(TransformerLayerInferTpl): - """ """ - - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - - self.use_qk_norm_ = self.network_config_.get("use_qk_norm", False) - return - - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - raise Exception("need to impl") - - def _bind_rotary_emb_fwd(self): - raise Exception("need to impl") - - def _get_qkv( - self, input, infer_state: InferStateInfo, layer_weight - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view( - -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ - ) - - if self.use_qk_norm_: - q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = self._q_norm(q, infer_state, layer_weight) - cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight) - self._rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _cohere_residual(self, input_embdings, infer_state: InferStateInfo): - # emb_addr = input_embdings.data_ptr() - # attn_out_addr = infer_state._attn_out.data_ptr() - # ffn_addr = infer_state._ffn_out.data_ptr() - # assert emb_addr != attn_out_addr - # assert emb_addr != ffn_addr - # assert attn_out_addr != ffn_addr - input_embdings.add_( - infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_) - ) - - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._context_attention(input1, infer_state, layer_weight=layer_weight) - self._context_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings - - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._token_attention(input1, infer_state, layer_weight=layer_weight) - self._token_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index fdfc9a193e..9153349c5d 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -11,8 +11,8 @@ class TransformerLayerInferTpl(TransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) # need to set by subclass self.eps_ = 1e-5 self.tp_q_head_num_ = -1 diff --git a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py index 7350531bbb..53daffcddf 100644 --- a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py @@ -4,9 +4,8 @@ class TransformerLayerInfer(BaseLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): + def __init__(self, layer_num, network_config): super().__init__() self.layer_num_ = layer_num self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py index 19eb67017d..8c81fd5bcc 100644 --- a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py @@ -3,11 +3,10 @@ class PreAndPostLayerWeight(BaseLayerWeight): - def __init__(self, data_type, network_config, mode): + def __init__(self, data_type, network_config): super().__init__() self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.init_static_params() return diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 97bc762370..4bc58c76f6 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -9,12 +9,11 @@ class TransformerLayerWeight(BaseLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__() self.layer_num_ = layer_num self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.quant_cfg = quant_cfg self._parse_config() self._init_weight_names() diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index afc8c93081..83f7674531 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.pre_norm_weight_ = NoTpNormWeight( weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 7b27ce6f2c..599893655d 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -48,8 +48,8 @@ def get_slopes_power_of_2(n): class BloomTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/cohere/__init__.py b/lightllm/models/cohere/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/cohere/infer_struct.py b/lightllm/models/cohere/infer_struct.py deleted file mode 100644 index d9571af92b..0000000000 --- a/lightllm/models/cohere/infer_struct.py +++ /dev/null @@ -1,8 +0,0 @@ -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - - -class CohereInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self._attn_out = None - self._ffn_out = None diff --git a/lightllm/models/cohere/layer_infer/__init__.py b/lightllm/models/cohere/layer_infer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py deleted file mode 100644 index 67987a8d3b..0000000000 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import numpy as np -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.common.build_utils import repair_config -from lightllm.distributed.communication_op import all_gather - - -class CoherePostLayerInfer(LlamaPostLayerInfer): - def __init__(self, network_config, mode): - repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) - super().__init__(network_config, mode) - self.eps_ = network_config["layer_norm_eps"] - self.logits_scale = network_config["logit_scale"] - return - - def _norm( - self, input: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ) -> torch.Tensor: - return layernorm_forward( - input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ - ).squeeze(1) - - def token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) - last_input = None - vocab_size = layer_weight.lm_head_weight_.vocab_size - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - gather_data = gather_data * self.logits_scale - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, vocab_size), - dtype=torch.float32, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - return ans_logics - - def tpsp_token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - raise NotImplementedError("not impl") - - def overlap_tpsp_token_forward( - self, - input_embdings: torch.Tensor, - input_embdings1: torch.Tensor, - infer_state: CohereInferStateInfo, - infer_state1: CohereInferStateInfo, - layer_weight: CoherePreAndPostLayerWeight, - ): - raise NotImplementedError("not impl") diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py deleted file mode 100644 index b3dcba937a..0000000000 --- a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from functools import partial - -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, -) -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm -from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd - - -class CohereTransformerLayerInfer(TransformerLayerCohereInferTpl): - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - self.eps_ = self.network_config_["layer_norm_eps"] - self.use_qk_norm_ = network_config.get("use_qk_norm", False) - self._bind_func() - - def _bind_func(self): - self._bind_rotary_emb_fwd() - self._bind_norm() - self._bind_attn() - - def _bind_norm(self): - self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) - self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) - self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) - - def _bind_attn(self): - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_attention_kernel, self) - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - return rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv, - position_cos, - position_sin, - ) - - def _bind_rotary_emb_fwd(self): - self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self) - - def _att_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward( - input.unsqueeze(1), layer_weight.att_norm_weight_.weight.unsqueeze(0), self.eps_ - ).squeeze(1) - - def _q_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, layer_weight.q_norm_weight_.weight, self.eps_) - - def _k_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, layer_weight.k_norm_weight_.weight, self.eps_) - - def _get_o( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - # o_tensor = layer_weight.mm_op.apply(input, layer_weight.o_weight_) - o_tensor = layer_weight.o_proj.mm(input) - return o_tensor - - def _ffn( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - input = None - up_gate_out = None - # ffn2_out = layer_weight.mm_op.apply(ffn1_out, layer_weight.down_proj) - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/cohere/layer_weights/__init__.py b/lightllm/models/cohere/layer_weights/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index f2e5f85472..0000000000 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class CoherePreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - tie_weight = self.network_config_.get("tie_word_embeddings", True) - - self.wte_weight_ = EmbeddingWeight( - weight_name="model.embed_tokens.weight", - data_type=self.data_type_, - ) - if tie_weight: - self.lm_head_weight_ = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="model.lm_head.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - weight_name="model.norm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py deleted file mode 100644 index 9c446b49e9..0000000000 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight, TpHeadNormWeight - - -class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) - return - - def _parse_config(self): - super()._parse_config() - self.use_qk_norm = self.network_config_.get("use_qk_norm", False) - - def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) - - if self.use_qk_norm: - self.q_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ - ) - self.k_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ - ) - - return diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py deleted file mode 100644 index 05ccaac3e3..0000000000 --- a/lightllm/models/cohere/model.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import torch -from lightllm.models.registry import ModelRegistry -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer -from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -@ModelRegistry("cohere") -class CohereTpPartModel(LlamaTpPartModel): - pre_and_post_weight_class = CoherePreAndPostLayerWeight - transformer_weight_class = CohereTransformerLayerWeight - - pre_layer_infer_class = LlamaPreLayerInfer - transformer_layer_infer_class = CohereTransformerLayerInfer - post_layer_infer_class = CoherePostLayerInfer - - infer_state_class = CohereInferStateInfo - - def _init_to_get_rotary(self, default_base=10000): - partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - - base = self.config.get("rope_theta", float(default_base)) - - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_position_embeddings = self.config.get( - "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384 - ) - max_seq_len = max_position_embeddings * rope_scaling_factor - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula - except: - pass - - inv_freq = 1.0 / ( - base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) - ) - t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - freqs = torch.repeat_interleave(freqs, 2, dim=-1) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/cohere/triton_kernels/__init__.py b/lightllm/models/cohere/triton_kernels/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py deleted file mode 100644 index c1d5ff4cd6..0000000000 --- a/lightllm/models/cohere/triton_kernels/layernorm.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import triton -import triton.language as tl - -# LayerNorm adapted from triton tutorial, used for Cohere q, k norm -# X [N, head_num, head_dim] -# W [head_num, head_dim] -@triton.jit -def _layer_norm_fwd_kernel( - X, # pointer to the input - W, # pointer to the weights - Y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, -): - Seq = tl.program_id(0) - H = tl.program_id(1) - - X += Seq * stride_x_N + H * stride_x_hn - Y += Seq * stride_y_N + H * stride_y_hn - W += H * stride_w_hn - - _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=0) / N - - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.0) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w - - tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask) - - -def layernorm_forward( - X, # pointer to the input - W, # pointer to the weights - eps, # epsilon to avoid division by zero -): - assert len(X.shape) == 3 - assert len(W.shape) == 2 - assert X.shape[-1] == W.shape[-1] - assert X.shape[-2] == W.shape[-2] - - y = torch.empty_like(X) - - stride_x_N = X.stride(0) - stride_x_hn = X.stride(1) - stride_x_hd = X.stride(2) - - stride_y_N = y.stride(0) - stride_y_hn = y.stride(1) - stride_y_hd = y.stride(2) - - stride_w_hn = W.stride(0) - stride_w_hd = W.stride(1) - - N = X.shape[-1] - BLOCK_SIZE = 128 - - grid = (X.shape[0], X.shape[1]) - _layer_norm_fwd_kernel[grid]( - X, - W, - y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, - eps, - BLOCK_SIZE, - ) - - return y - - -def torch_layernorm(x, weight, eps): - inp_dtype = x.dtype - x = x.to(torch.float32) - mean = x.mean(-1, keepdim=True) - variance = (x - mean).pow(2).mean(-1, keepdim=True) - x = (x - mean) * torch.rsqrt(variance + eps) - x = weight.to(torch.float32) * x - return x.to(inp_dtype) - - -def test_layernorm(eps=1e-5): - # create data - dtype = torch.float16 - x_shape = (5, 1, 128) - w_shape = (x_shape[-2], x_shape[-1]) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_ref = torch_layernorm(x, weight, eps).to(dtype) - y_out = layernorm_forward(x, weight, eps) - - # compare - print("type:", y_out.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_out - y_ref))) - assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0) - return diff --git a/lightllm/models/cohere/triton_kernels/rotary_emb.py b/lightllm/models/cohere/triton_kernels/rotary_emb.py deleted file mode 100644 index ac338e71ef..0000000000 --- a/lightllm/models/cohere/triton_kernels/rotary_emb.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos0 - q1 * sin0 - out1 = q0 * sin1 + q1 * cos1 - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos0 - k1 * sin0 - out_k1 = k0 * sin1 + k1 * cos1 - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -def torch_cohere_rotary_emb(x, cos, sin): - dtype = x.dtype - seq_len, h, dim = x.shape - x = x.float() - x1 = x[:, :, ::2] - x2 = x[:, :, 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - cos = cos.view((seq_len, 1, dim)) - sin = sin.view((seq_len, 1, dim)) - o = (x * cos) + (rot_x * sin) - return o.to(dtype=dtype) - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = int(q.shape[2] * partial_rotary_factor) - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (SEQ_LEN, H, D) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - y = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - cos_shape = (SEQ_LEN, D) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = torch_cohere_rotary_emb(x, cos, sin) - rotary_emb_fwd(x, y, cos, sin) - y_ref = x - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) From 93f87cf84d5426d8c86128029d26cedd3b3442e8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 08:23:39 +0000 Subject: [PATCH 087/114] fix mode --- lightllm/models/__init__.py | 1 - .../models/bloom/layer_infer/post_layer_infer.py | 4 ++-- lightllm/models/bloom/layer_infer/pre_layer_infer.py | 4 ++-- .../bloom/layer_infer/transformer_layer_infer.py | 4 ++-- .../deepseek2/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../deepseek_mtp/layer_infer/pre_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../models/gemma3/layer_infer/post_layer_infer.py | 4 ++-- .../models/gemma3/layer_infer/pre_layer_infer.py | 4 ++-- .../gemma3/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../gemma3/layer_weights/transformer_layer_weight.py | 3 +-- .../models/gemma_2b/layer_infer/pre_layer_infer.py | 4 ++-- .../gemma_2b/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../gpt_oss/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 3 +-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 12 ++++++------ .../models/llama/layer_infer/post_layer_infer.py | 4 ++-- lightllm/models/llama/layer_infer/pre_layer_infer.py | 4 ++-- .../llama/layer_infer/transformer_layer_infer.py | 4 ++-- .../llama/layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../llama/layer_weights/transformer_layer_weight.py | 3 +-- .../llava/layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../mistral/layer_infer/transformer_layer_infer.py | 4 ++-- .../mistral_mtp/layer_infer/post_layer_infer.py | 4 ++-- .../mistral_mtp/layer_infer/pre_layer_infer.py | 4 ++-- .../layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 4 ++-- .../mixtral/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 3 +-- .../phi3/layer_infer/transformer_layer_infer.py | 4 ++-- .../phi3/layer_weights/transformer_layer_weight.py | 4 ++-- .../qwen/layer_infer/transformer_layer_infer.py | 4 ++-- .../qwen/layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../qwen2/layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../qwen2_vl/layer_infer/transformer_layer_infer.py | 4 ++-- .../qwen3/layer_infer/transformer_layer_infer.py | 4 ++-- .../qwen3_moe/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../models/qwen3_vl/layer_infer/pre_layer_infer.py | 4 ++-- .../qwen3_vl/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../models/qwen_vl/layer_infer/pre_layer_infer.py | 4 ++-- .../stablelm/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../models/starcoder/layer_infer/pre_layer_infer.py | 4 ++-- .../starcoder/layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_infer/transformer_layer_infer.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../layer_weights/pre_and_post_layer_weight.py | 8 ++++---- lightllm/models/vit/layer_infer/post_layer_infer.py | 3 +-- lightllm/models/vit/layer_infer/pre_layer_infer.py | 3 +-- .../vit/layer_infer/transformer_layer_infer.py | 3 +-- .../vit/layer_weights/pre_and_post_layer_weight.py | 4 ++-- .../vit/layer_weights/transformer_layer_weight.py | 4 ++-- lightllm/models/vit/model.py | 12 ++++-------- 71 files changed, 141 insertions(+), 153 deletions(-) diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index afc3fc660a..539b32decb 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -1,4 +1,3 @@ -from lightllm.models.cohere.model import CohereTpPartModel from lightllm.models.mixtral.model import MixtralTpPartModel from lightllm.models.bloom.model import BloomTpPartModel from lightllm.models.llama.model import LlamaTpPartModel diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 7938869f5a..f4fff116cd 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -10,9 +10,9 @@ class BloomPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): repair_config(config=network_config, same_names=["layer_norm_epsilon", "rms_norm_eps"]) - super().__init__(network_config, mode) + super().__init__(network_config) return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index baf1d3084d..dfe396ab52 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class BloomPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["layer_norm_epsilon"] return diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 0316c3652b..808788f71a 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -9,8 +9,8 @@ class BloomTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["layer_norm_epsilon"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = self.tp_q_head_num_ diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index de750bb581..133336dfeb 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -19,7 +19,7 @@ class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self.qk_nope_head_dim = network_config["qk_nope_head_dim"] @@ -51,7 +51,7 @@ def __init__(self, layer_num, network_config, mode=[]): mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] return diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 611878f9e8..c5a2d33527 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -17,9 +17,9 @@ class Deepseek2TransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index 26bfc865e4..adb749c40e 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] self.hidden_size = network_config["hidden_size"] return diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 4a5bf2e961..1f0815c3db 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.eh_proj.weight", diff --git a/lightllm/models/gemma3/layer_infer/post_layer_infer.py b/lightllm/models/gemma3/layer_infer/post_layer_infer.py index 22dc595059..57b21844ec 100644 --- a/lightllm/models/gemma3/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/post_layer_infer.py @@ -4,7 +4,7 @@ class Gemma3PostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-6 return diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index dc8a46ad91..3543786f69 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -5,8 +5,8 @@ class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32) self.boi_token_index: int = 255_999 self.eoi_token_index: int = 256_000 diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 6f87710917..1f386625bf 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -12,8 +12,8 @@ class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] self.tp_v_head_num_ = network_config["num_key_value_heads"] self.eps_ = 1e-6 diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 17e65268cc..858937d8c1 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="language_model.model.embed_tokens.weight", diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 1e7ceeb42a..e7808c412c 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -9,10 +9,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index ce9737820e..468d471d2c 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -11,8 +11,8 @@ class Gemma_2bPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) self.normfactor = network_config["hidden_size"] ** 0.5 diff --git a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py index 35ddaef343..2ed325659d 100644 --- a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py @@ -16,8 +16,8 @@ class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1 self.tp_v_head_num_ = network_config["num_key_value_heads"] return diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index d5d0438fa3..6e052caa63 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 1916bd095c..9102ce6775 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_qkv(self): diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d02e0ba8a8..d80eefd16e 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -8,8 +8,8 @@ class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.hidden_size = self.network_config_["hidden_size"] self.alpha = 1.702 self.limit = 7.0 diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index f6a841b1aa..c5c14b08e6 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -17,10 +17,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_moe(self): diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index 858c192f4c..6ef81122d3 100755 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class InternlmTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index b40330aa3d..3ed7004c12 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py index a675558632..a05e977f16 100755 --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Internlm2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index b20b9c4955..59caf40d6b 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.tok_embeddings.weight", data_type=self.data_type_, diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index 7d76d202ae..21a4c2e6b5 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -24,8 +24,8 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -35,8 +35,8 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 7c7b0ea39b..8bc10d623c 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -13,8 +13,8 @@ class LlamaPostLayerInfer(PostLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] return diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index ddb99e2627..f4f150b173 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -10,8 +10,8 @@ class LlamaPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 048101c998..2a9a543196 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -16,8 +16,8 @@ class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["rms_norm_eps"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = max(network_config["num_key_value_heads"] // self.tp_world_size_, 1) diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index ea59d24dfc..7e9ff41673 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 6b92272ee7..197116d99c 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -11,10 +11,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py index e0e2e11845..3afcfb0a71 100644 --- a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py @@ -12,8 +12,8 @@ def rename_weight_keys(weights): class LlavaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 0952468d0f..45023bdf8f 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) hidden_size = self.network_config_["hidden_size"] dim_model_base = self.network_config_.get("dim_model_base", hidden_size) self.lm_head_scale = hidden_size / dim_model_base diff --git a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py index 2bc5078382..c37b524fde 100755 --- a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py index 59eef6daa7..d115c30ec1 100755 --- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -4,7 +4,7 @@ class MistralTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config.get("head_dim", self.head_dim_) return diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py index 5eac249bad..f890fbf663 100644 --- a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -4,6 +4,6 @@ class MistralMTPPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 25bea1aa60..dbe9b61c85 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -7,8 +7,8 @@ class MistralMTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def _mtp_context_forward( diff --git a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py index 5724f32af9..6d72ae2c38 100644 --- a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -10,8 +10,8 @@ class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index 2fbc89cfd0..c9032f6fee 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="mtp.eh_proj.weight", diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 6607dbb704..08f280b06c 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MistralMTPTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index c5842b8289..44e66cff2d 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -8,8 +8,8 @@ class MixtralTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.num_local_experts = network_config["num_local_experts"] self.num_experts_per_tok = network_config["num_experts_per_tok"] self.renormalize = True diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index f425ad08ba..39e28d4655 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -8,12 +8,11 @@ class MixtralTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__( layer_num, data_type, network_config, - mode, quant_cfg=quant_cfg, ) return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index 0995b1414d..fd3d05e426 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -7,8 +7,8 @@ class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): diff --git a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py index 91b2730917..db4906c19a 100755 --- a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Phi3TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 2576d7affd..333870eb9d 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -9,8 +9,8 @@ class QwenTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 00f68eee69..bf9282a979 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index a8a57c02ed..6449430d9e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -2,6 +2,6 @@ class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index 7cf6366223..3c974691d5 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) del self.lm_head_weight_ self.score_up_weight_ = ROWMMWeight( weight_names="score.0.weight", diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index 19e17c36e8..298a77044c 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) mrope_section = network_config["rope_scaling"]["mrope_section"] self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 20f135e761..5f0c91287d 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -13,8 +13,8 @@ class Qwen3TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] return diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 10a734e5c3..c85c423c29 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -19,7 +19,7 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 @@ -28,7 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]): ) self.num_experts_per_tok = network_config["num_experts_per_tok"] self.norm_topk_prob = network_config["norm_topk_prob"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py index d219173401..4e2b65d743 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -14,8 +14,8 @@ class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 6cc447a594..8ba95c1386 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -9,8 +9,8 @@ class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.proj.weight", diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index 96e453ebe7..c24166e13d 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward( diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 9ce475e974..d1c51365a1 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -14,8 +14,8 @@ class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" diff --git a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py index 5d41d85515..8a380853de 100644 --- a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py @@ -12,8 +12,8 @@ def rename_weight_keys(weights): class Qwen3VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index facad2e56b..328cc0a625 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -11,8 +11,8 @@ class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index b1f5ee6600..52a982f495 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index f439073077..939843a3eb 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -24,8 +24,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 395ed4ba1a..f908dbdd3b 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -8,8 +8,8 @@ class StablelmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.partial_rotary_factor = self.network_config_.get("partial_rotary_factor", 1) return diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 0ad3e07df5..3d044eeb56 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -2,8 +2,8 @@ class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.final_norm_weight_ = NoTpNormWeight( weight_name="model.norm.weight", data_type=self.data_type_, diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index 52072a3487..6b88c066ee 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class StarcoderPreLayerInfer(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.layer_norm_eps_ = network_config["layer_norm_epsilon"] def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): diff --git a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py index 561ffc316f..074f3411a7 100644 --- a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py @@ -6,8 +6,8 @@ class StarcoderTransformerLayerInfer(BloomTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self._bind_func() diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index d5bdd79a7b..329a0245f0 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 796a96bc4a..09e3299eb6 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Starcoder2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 28a26cb4b3..6ee1885372 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py index b24fc0f0d1..44e18c2826 100644 --- a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py @@ -17,8 +17,8 @@ def rename_weight_keys(weights): class Tarsier2Qwen2PreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -28,8 +28,8 @@ def load_hf_weights(self, weights): class Tarsier2LlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py index 613aec3fa7..fa4a87f158 100644 --- a/lightllm/models/vit/layer_infer/post_layer_infer.py +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -9,11 +9,10 @@ class ViTPostLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode self.llm_hidden_size = network_config["llm_hidden_size"] self.downsample_ratio = network_config["downsample_ratio"] return diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index 896e8e898c..306bf9f0e6 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -11,11 +11,10 @@ class ViTPreLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0b89dca11e..0d55d1b57f 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -13,7 +13,7 @@ class ViTTransformerLayerInfer: """ """ - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.eps_ = network_config["layer_norm_eps"] @@ -25,7 +25,6 @@ def __init__(self, layer_num, network_config, mode=[]): self.tp_padding_embed_dim_ = self.tp_padding_head_num * self.head_dim_ self.network_config_ = network_config - self.mode = mode self.layer_num_ = layer_num return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0b..e2bed10361 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -7,8 +7,8 @@ class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.embed_dim = self.network_config_["hidden_size"] self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index c6024594e3..dffcc16fe8 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -14,8 +14,8 @@ class ViTTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index b8e6eaf929..9c2bc42426 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -40,7 +40,6 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") - self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] self.weight_dict = kvargs.get("weight_dict", None) self.data_type = kvargs.get("data_type", "float16") self.quant_type = kvargs.get("quant_type", None) @@ -112,15 +111,12 @@ def _padding_hidden_size(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(self.config["num_hidden_layers"]) @@ -141,10 +137,10 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(self.config["num_hidden_layers"]) ] return From 0852f64752825874ee50b208914fdd75199c0ea7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 08:26:28 +0000 Subject: [PATCH 088/114] fix mode. --- .../models/qwen/layer_weights/transformer_layer_weight.py | 4 ++-- .../models/qwen2/layer_weights/transformer_layer_weight.py | 4 ++-- .../models/qwen3/layer_weights/transformer_layer_weight.py | 4 ++-- .../qwen3_moe/layer_weights/transformer_layer_weight.py | 4 ++-- .../qwen3_moe_mtp/layer_weights/transformer_layer_weight.py | 4 ++-- .../qwen3_vl_moe/layer_weights/transformers_layer_weight.py | 4 ++-- .../models/stablelm/layer_weights/transformer_layer_weight.py | 4 ++-- .../starcoder/layer_weights/transformer_layer_weight.py | 4 ++-- .../starcoder2/layer_weights/transformer_layer_weight.py | 4 ++-- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 9afb964ad2..ac1bf91f4b 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class QwenTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): qkv_weight_name = f"transformer.h.{self.layer_num_}.attn.c_attn.weight" diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 6962818c49..9c3e2cb3a8 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Qwen2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def _init_weight_names(self): super()._init_weight_names() diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 86b9e172a9..90b7810adf 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -5,8 +5,8 @@ class Qwen3TransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6f..486f4d6966 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -4,14 +4,14 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 and layer_num not in network_config["mlp_only_layers"] and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 ) - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 22d4d19505..095afecd91 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py index f4eef6e698..48ddf52089 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): moe_prefix = f"model.layers.{self.layer_num_}.mlp.experts" diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index a1a73f6745..03ee50feb5 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class StablelmTransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 2aa9dd9ef2..41f24f79cb 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class StarcoderTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg, layer_prefix="transformer.h") + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg, layer_prefix="transformer.h") assert network_config["num_attention_heads"] % self.tp_world_size_ == 0 def load_hf_weights(self, weights): diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py index 6314fa0e57..53342e221f 100644 --- a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class Starcoder2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): From 8f59c773569f60e24e55173c213444f54265f8ed Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 08:50:38 +0000 Subject: [PATCH 089/114] remove max_len_in_batch --- lightllm/common/basemodel/basemodel.py | 3 --- lightllm/common/basemodel/batch_objs.py | 1 - lightllm/common/basemodel/cuda_graph.py | 2 -- lightllm/common/basemodel/prefill_cuda_graph.py | 1 - .../model_infer/mode_backend/chunked_prefill/impl.py | 2 +- .../router/model_infer/mode_backend/dp_backend/impl.py | 6 +++--- .../mode_backend/generic_padded_pre_process.py | 4 ---- .../model_infer/mode_backend/generic_pre_process.py | 3 --- test/benchmark/static_inference/model_infer.py | 10 ---------- test/benchmark/static_inference/model_infer_mtp.py | 2 -- 10 files changed, 4 insertions(+), 30 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 9c30b1ff44..e27f87fba5 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -857,7 +857,6 @@ def _check_max_len_infer(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=self.batch_max_tokens, max_q_seq_len=self.batch_max_tokens, max_kv_seq_len=self.batch_max_tokens, max_cache_len=0, @@ -934,7 +933,6 @@ def _autotune_warmup(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=input_len, max_q_seq_len=input_len, max_kv_seq_len=input_len, max_cache_len=0, @@ -997,7 +995,6 @@ def _init_padded_req(self): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=prefill_input_len, max_q_seq_len=prefill_input_len, max_kv_seq_len=prefill_input_len, max_cache_len=0, diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 138f084270..8c9d28e86c 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -11,7 +11,6 @@ class ModelInput: # 通用变量 batch_size: int total_token_num: int - max_len_in_batch: int # 在 decode 阶段, 常规模式下, max_q_seq_len 必定是 1, # 在 mtp 模式下,max_q_seq_len 统计的是一个请求考虑了 mtp 步数的 # 最大长度,实际值是 max([(1 + req.mtp_step) for req in reqs]) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 9eeab7270c..74fed18c44 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -196,7 +196,6 @@ def warmup(self, model): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=self.mtp_step + 1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, @@ -256,7 +255,6 @@ def warmup_overlap(self, model): is_prefill=False, batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=self.mtp_step + 1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3d77a3ae4c..4e9af4bb1b 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -168,7 +168,6 @@ def warmup(self, model): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index f3450261b7..a8a5224ebc 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -399,7 +399,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] draft_model_input.mem_indexes = torch.cat( [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index df10a6d4e6..bb0e848e76 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -591,7 +591,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) # update the meta info of the inference draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -955,7 +955,7 @@ def _draft_decode_eagle_overlap( ) draft_model_input0.b_seq_len += 1 - draft_model_input0.max_len_in_batch += 1 + draft_model_input0.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes0[_step * real_req_num0 : (_step + 1) * real_req_num0] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -969,7 +969,7 @@ def _draft_decode_eagle_overlap( ).view(-1) draft_model_input1.b_seq_len += 1 - draft_model_input1.max_len_in_batch += 1 + draft_model_input1.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes1[_step * real_req_num1 : (_step + 1) * real_req_num1] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 7845b9b2ca..03ac4cfb05 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -73,7 +73,6 @@ def padded_prepare_prefill_inputs( max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu") @@ -102,7 +101,6 @@ def padded_prepare_prefill_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -185,7 +183,6 @@ def padded_prepare_decode_inputs( max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_seq_len) b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") @@ -210,7 +207,6 @@ def padded_prepare_decode_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 963394116a..f1cb0326ff 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -50,7 +50,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) - max_len_in_batch = max(b_q_seq_len) max_q_seq_len = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) @@ -72,7 +71,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -147,7 +145,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 3fc7ee4b45..a8abd2ae64 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -73,7 +73,6 @@ def overlap_prefill( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: total_token_num // 2] _0_mem_indexes = mem_indexes[: total_token_num // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -83,7 +82,6 @@ def overlap_prefill( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -96,7 +94,6 @@ def overlap_prefill( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[total_token_num // 2 :] _1_mem_indexes = mem_indexes[total_token_num // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -107,7 +104,6 @@ def overlap_prefill( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -129,7 +125,6 @@ def overlap_decode( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: batch_size // 2] _0_mem_indexes = mem_indexes[: batch_size // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -138,7 +133,6 @@ def overlap_decode( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -149,7 +143,6 @@ def overlap_decode( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[batch_size // 2 :] _1_mem_indexes = mem_indexes[batch_size // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -159,7 +152,6 @@ def overlap_decode( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -191,7 +183,6 @@ def prefill( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_len_in_batch, max_kv_seq_len=max_len_in_batch, max_cache_len=0, @@ -217,7 +208,6 @@ def decode( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 942af0f883..07ad52a132 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -129,7 +129,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=input_len, input_ids=test_data, mem_indexes=mem_indexes, b_req_idx=b_req_idx, @@ -197,7 +196,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, input_ids=decode_input_ids, mem_indexes=mem_indexes, b_req_idx=nopad_b_seq_idx, From d193c602d0372835e4d0156cac6356de15a3dce7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 09:19:38 +0000 Subject: [PATCH 090/114] remove max_len_in_batch --- lightllm/common/basemodel/attention/triton/fp.py | 8 ++++---- lightllm/common/basemodel/basemodel.py | 6 ++---- lightllm/common/basemodel/infer_struct.py | 6 ------ lightllm/common/basemodel/prefill_cuda_graph.py | 1 - .../gqa/flash_decoding/gqa_flash_decoding.py | 8 ++++---- .../gqa/flash_decoding/gqa_flash_decoding_vsm.py | 2 +- .../att/decode_att/int4kv/ppl_int4kv_flash_decoding.py | 8 ++++---- .../att/decode_att/int8kv/ppl_int8kv_flash_decoding.py | 8 ++++---- .../int8kv/ppl_int8kv_flash_decoding_diverse.py | 10 +++++----- .../decode_att/mha/flash_decoding/flash_decoding.py | 8 ++++---- .../att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py | 8 ++++---- .../mla_att/decode_att/gqa_flash_decoding.py | 4 ++-- .../deepseek2/triton_kernel/gqa_flash_decoding_fp8.py | 6 +++--- .../model_infer/mode_backend/generic_pre_process.py | 3 --- 14 files changed, 37 insertions(+), 49 deletions(-) diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index da9e5205d4..1ff0aec0c1 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -54,7 +54,7 @@ def _alibi_prefill_att( self.infer_state.b_q_start_loc, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, - self.infer_state.max_len_in_batch, + self.infer_state.max_q_seq_len, self.infer_state.req_manager.req_to_token_indexs, ) return out @@ -72,7 +72,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, self.infer_state.b_q_start_loc, self.infer_state.b_seq_len, self.infer_state.b_ready_cache_len, - self.infer_state.max_len_in_batch, + self.infer_state.max_q_seq_len, self.infer_state.req_manager.req_to_token_indexs, ) return out @@ -129,7 +129,7 @@ def _alibi_decode_att( self.infer_state.b_req_idx, self.infer_state.b_kv_start_loc, self.infer_state.b_seq_len, - self.infer_state.max_len_in_batch, + self.infer_state.max_kv_seq_len, self.infer_state.total_token_num, alloc_tensor_func=alloc_func, ) @@ -255,7 +255,7 @@ def _normal_decode_stage3_att( B_req_idx=self.infer_state.b_req_idx, B_Start_Loc=self.infer_state.b_kv_start_loc, B_Seqlen=self.infer_state.b_seq_len, - max_len_in_batch=self.infer_state.max_len_in_batch, + max_len_in_batch=self.infer_state.max_kv_seq_len, ) o_tensor = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e27f87fba5..6c64311db8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -296,7 +296,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.return_all_prompt_logics = self.return_all_prompt_logics infer_state.batch_size = model_input.batch_size infer_state.total_token_num = model_input.total_token_num - infer_state.max_len_in_batch = model_input.max_len_in_batch infer_state.max_q_seq_len = model_input.max_q_seq_len infer_state.max_kv_seq_len = model_input.max_kv_seq_len infer_state.max_cache_len = model_input.max_cache_len @@ -394,7 +393,6 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle new_model_input = copy.copy(model_input) new_model_input.batch_size = model_input.batch_size + 1 new_model_input.total_token_num += padded_token_num - new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch) new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len) new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len) new_model_input.max_cache_len = max(0, model_input.max_cache_len) @@ -513,7 +511,7 @@ def _decode( model_input.b_mtp_index, ) - if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): + if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_kv_seq_len): find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) @@ -705,7 +703,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode assert model_input1.mem_indexes.is_cuda origin_batch_size = model_input0.batch_size - max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) + max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index e046197077..52726981bd 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -38,10 +38,6 @@ def __init__(self): self.b_mark_shared_group: torch.Tensor = None # only for diverse mode used in decode phase. self.b_seq_len: torch.Tensor = None - # max_len_in_batch prefill 和 decode 阶段含义不同 - # prefill 阶段指每个req 输入token的长度(不包括已经cache的部分)最大值 - # decode 阶段指的是每个req的总长 最大值 - self.max_len_in_batch: int = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -124,8 +120,6 @@ def init_some_extra_state(self, model): self.b1_cu_kv_seq_len, self.position_ids, ) = gen_decode_params(self.b_seq_len) - # TODO: check the correctness - self.max_kv_seq_len = self.max_len_in_batch self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1] def init_att_state(self): diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 4e9af4bb1b..a8b2616418 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -228,7 +228,6 @@ def warmup_overlap(self, model): micro_batch = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index c56bf7d5ab..26ec3ebd71 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -6,7 +6,7 @@ def gqa_token_decode_attention_flash_decoding( ): BLOCK_SEQ = 128 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) @@ -16,10 +16,10 @@ def gqa_token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -29,7 +29,7 @@ def gqa_token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 850d4185c3..6a9bb79c7d 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -421,7 +421,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = infer_state.max_len_in_batch + avg_seq_len_in_batch = infer_state.max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py index 8c61ed3c4e..76184ea2e2 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -14,7 +14,7 @@ def token_decode_attention_flash_decoding( ): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len q_head_num = q.shape[1] head_dim = q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) @@ -24,10 +24,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) light_ops.group8_int4kv_flashdecoding_stage1( @@ -43,7 +43,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py index a02ce88a95..f51d611661 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py @@ -15,7 +15,7 @@ def token_decode_attention_flash_decoding( BLOCK_SEQ = 256 q_head_num, head_dim = q.shape[1], q.shape[2] batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 @@ -23,10 +23,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" ) light_ops.group8_int8kv_flashdecoding_stage1( @@ -42,7 +42,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py index d42a1a12a3..6efb030ce6 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py @@ -31,16 +31,16 @@ def token_decode_attention_flash_decoding( BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" ) current_stream = torch.cuda.current_stream() @@ -57,7 +57,7 @@ def token_decode_attention_flash_decoding( B_req_idx=infer_state.b_req_idx, b_shared_seq_len=infer_state.b_shared_seq_len, b_mark_shared_group=infer_state.b_mark_shared_group, - max_len_in_batch=infer_state.max_len_in_batch, + max_len_in_batch=infer_state.max_kv_seq_len, mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, @@ -79,7 +79,7 @@ def token_decode_attention_flash_decoding( infer_state.b_req_idx, infer_state.b_seq_len, infer_state.b_shared_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) current_stream.wait_stream(stream1) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py index a386212486..6c50fc3927 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py @@ -4,7 +4,7 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) @@ -14,10 +14,10 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -27,7 +27,7 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py index fc21848e16..b0a9b6245c 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py @@ -7,7 +7,7 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= batch_size = infer_state.batch_size q_head_num = q.shape[1] head_dim = q.shape[2] - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 @@ -15,10 +15,10 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) light_ops.fp16_flashdecoding_stage1( @@ -32,7 +32,7 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 9d5f6bb8c9..28839b5f59 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -15,7 +15,7 @@ def gqa_token_decode_attention_flash_decoding( q_nope, q_rope, kv_nope, kv_rope, infer_state, softmax_scale, out=None, alloc_tensor_func=torch.empty, **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len q_head_num, kv_lora_rank = q_nope.shape[1], q_nope.shape[2] q_rope_dim = q_rope.shape[2] @@ -26,7 +26,7 @@ def gqa_token_decode_attention_flash_decoding( if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py index ed2f564b5a..b9be73e278 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py @@ -27,13 +27,13 @@ def gqa_token_decode_attention_flash_decoding_fp8( **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size @@ -192,7 +192,7 @@ def _fwd_kernel_calcu_index_and_block_seq( infer_state = Deepseek2InferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_kv_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX infer_state.req_manager = ReqManager(Z, N_CTX, None) infer_state.req_manager.req_to_token_indexs = req_to_token_indexs diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index f1cb0326ff..4eb8c7e1e6 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -93,7 +93,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs: List[InferReq] = [] total_token_num = 0 - max_len_in_batch = 0 b_req_idx = [] b_mtp_index = [] b_seq_len = [] @@ -107,7 +106,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_seq_len.append(seq_len) b_q_seq_len.append(1) total_token_num += seq_len - max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. @@ -117,7 +115,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len += 1 b_seq_len.append(seq_len) total_token_num += seq_len - max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) From 0c6818dc9858e7d36cff8209a0acbb83b49c2a61 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 09:30:00 +0000 Subject: [PATCH 091/114] fix cuda graph. --- lightllm/common/basemodel/cuda_graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 74fed18c44..6c5f835784 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -67,7 +67,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size # warmup # 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上 @@ -100,9 +100,9 @@ def _capture_decode_overlap( graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size - infer_state1.max_len_in_batch = self.graph_max_len_in_batch + infer_state1.max_kv_seq_len = self.graph_max_len_in_batch infer_state1.total_token_num = self.graph_max_len_in_batch * batch_size # warmup for _ in range(1): @@ -196,7 +196,7 @@ def warmup(self, model): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, mem_indexes=mem_indexes, @@ -255,7 +255,7 @@ def warmup_overlap(self, model): is_prefill=False, batch_size=batch_size, total_token_num=total_token_num, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, b_mtp_index=b_mtp_index, From 31fea47005ad1c626db7cc84c488d4a0382444dc Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 09:42:21 +0000 Subject: [PATCH 092/114] fix cuda graph --- lightllm/common/basemodel/attention/base_att.py | 7 +++++-- lightllm/common/basemodel/attention/fa3/fp.py | 2 +- lightllm/common/basemodel/attention/fa3/fp8.py | 2 +- lightllm/common/basemodel/attention/fa3/mla.py | 2 +- lightllm/common/basemodel/attention/flashinfer/fp.py | 1 + lightllm/common/basemodel/attention/flashinfer/fp8.py | 3 +++ lightllm/common/basemodel/attention/flashinfer/mla.py | 1 + lightllm/common/basemodel/attention/triton/fp.py | 2 +- lightllm/common/basemodel/attention/triton/int4kv.py | 2 +- lightllm/common/basemodel/attention/triton/int8kv.py | 2 +- lightllm/common/basemodel/attention/triton/mla.py | 2 +- lightllm/common/basemodel/infer_struct.py | 4 ++-- 12 files changed, 19 insertions(+), 11 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index c15a254567..42f33a452a 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -98,9 +98,12 @@ class BaseDecodeAttState(ABC): def init_state(self): pass - @abstractmethod def copy_for_decode_cuda_graph(self, new_state: "BaseDecodeAttState"): - pass + for attr_name, attr_value in vars(new_state).items(): + if isinstance(attr_value, torch.Tensor): + attr_ = getattr(self, attr_name, None) + if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): + attr_.copy_(attr_value, non_blocking=True) @abstractmethod def decode_att( diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 7058cdb394..be248692e0 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -184,7 +184,7 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 41dcde125a..75b3a2657d 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -165,7 +165,7 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index ed6182fe43..9af8970617 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -167,7 +167,7 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "MlaFa3DecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 21f30dd6f4..4c6ec0efc6 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -179,6 +179,7 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) self.decode_wrapper.plan( new_state.kv_starts, new_state.kv_indices, diff --git a/lightllm/common/basemodel/attention/flashinfer/fp8.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py index ee20f40e8e..115d6985ac 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp8.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp8.py @@ -72,6 +72,9 @@ def init_state(self): super().init_state() self.offline_scales = self.infer_state.mem_manager.scales_list + def copy_for_decode_cuda_graph(self, new_state): + return super().copy_for_decode_cuda_graph(new_state) + def decode_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 8786a86057..bed52db94c 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -169,6 +169,7 @@ def init_state(self): return def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) self.decode_wrapper.plan( new_state.q_indptr, new_state.kv_starts, diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index 1ff0aec0c1..d29f15ec3b 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -84,7 +84,7 @@ def init_state(self): pass def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py index 14a194a438..6a7acaef66 100644 --- a/lightllm/common/basemodel/attention/triton/int4kv.py +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -118,7 +118,7 @@ def init_state(self): pass def copy_for_decode_cuda_graph(self, new_state: "Int4kvTritonDecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index 1471fbd699..6a795c4376 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -119,7 +119,7 @@ def init_state(self): pass def copy_for_decode_cuda_graph(self, new_state: "Int8kvTritonDecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index 6fe171120e..5689e4979a 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -71,7 +71,7 @@ def init_state(self): pass def copy_for_decode_cuda_graph(self, new_state: "MlaTritonDecodeAttState"): - pass + super().copy_for_decode_cuda_graph(new_state) def decode_att( self, diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 52726981bd..75856b1086 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -139,9 +139,9 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): attr_.copy_(attr_value, non_blocking=True) - self.decode_att_state.copy_for_decode_cuda_graph() + self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state) if self.decode_att_state1 is not None: - self.decode_att_state1.copy_for_decode_cuda_graph() + self.decode_att_state1.copy_for_decode_cuda_graph(new_infer_state.decode_att_state1) return def prefill_dp_balance(self, input_ids: torch.Tensor): From 4edf70ea85b42b6dd378339b2273cceeb0bd047d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 12:38:43 +0000 Subject: [PATCH 093/114] fix --- lightllm/common/basemodel/attention/fa3/fp.py | 4 +--- lightllm/common/basemodel/attention/fa3/fp8.py | 5 +---- lightllm/common/basemodel/attention/fa3/mla.py | 4 +--- .../common/basemodel/triton_kernel/gen_prefill_params.py | 1 + 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index be248692e0..952bb39d91 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -136,9 +136,7 @@ def init_state(self): device=self.infer_state.b_seq_len.device, ) b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( - b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] - ) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) self.cu_seqlens_q = b1_cu_q_seq_len.int() self.cu_seqlens_k = b1_cu_kv_seq_len.int() else: diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 75b3a2657d..3feed1ef46 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -1,13 +1,10 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import AttControl from typing import Optional, TYPE_CHECKING -from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant -from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 9af8970617..95a4120ff5 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -119,9 +119,7 @@ def init_state(self): device=self.infer_state.b_seq_len.device, ) b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( - b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size] - ) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) self.cu_seqlens_q = b1_cu_q_seq_len.int() self.cu_seqlens_k = b1_cu_kv_seq_len.int() else: diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index e73b342994..8f9172b552 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -43,6 +43,7 @@ def _gen_cumsum_pad0_kernel( def gen_cumsum_pad0_tensor(b_q_seq_len: torch.Tensor, b_kv_seq_len: torch.Tensor): assert len(b_q_seq_len.shape) == 1 assert b_q_seq_len.shape == b_kv_seq_len.shape + assert b_q_seq_len.is_contiguous() b1_cu_q_seq_len = torch.empty((b_q_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") b1_cu_kv_seq_len = torch.empty((b_kv_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") From 9d2cb3a10c9ffe5f1822d03cb62075c1267dd4d8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 13:30:29 +0000 Subject: [PATCH 094/114] fix --- lightllm/common/basemodel/batch_objs.py | 4 +--- .../gqa/flash_decoding/gqa_flash_decoding_stage1.py | 12 +++++++++++- .../ppl_int8kv_flash_decoding_diverse_stage1.py | 3 ++- .../mha/flash_decoding/flash_decoding_stage1.py | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 8c9d28e86c..758c0b5194 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -11,9 +11,7 @@ class ModelInput: # 通用变量 batch_size: int total_token_num: int - # 在 decode 阶段, 常规模式下, max_q_seq_len 必定是 1, - # 在 mtp 模式下,max_q_seq_len 统计的是一个请求考虑了 mtp 步数的 - # 最大长度,实际值是 max([(1 + req.mtp_step) for req in reqs]) + # 在 decode 阶段, max_q_seq_len 必定是 1, max_q_seq_len: int max_kv_seq_len: int max_cache_len: int = None diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 320c2cf798..2814ff44bc 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -123,8 +123,18 @@ def _fwd_kernel_flash_decode_stage1( @torch.no_grad() def flash_decode_stage1( - q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq + q, + k: torch.Tensor, + v: torch.Tensor, + Req_to_tokens, + B_req_idx, + B_Seqlen, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, ): + assert k.stride() == v.stride() BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py index 8b3423ce99..7403f6dd5c 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py @@ -269,12 +269,13 @@ def flash_decode_stage1( gqa_group_size = q.shape[1] // k.shape[1] assert triton.next_power_of_2(Lk) == Lk KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] - assert KV_QUANT_GROUP_SIZE == 8 + assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) if BLOCK_HEAD * BLOCK_BATCH < 16: BLOCK_BATCH = 16 // BLOCK_HEAD + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, stride_qbs=q.stride(0), diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py index 4691e2db50..f41a5c8fde 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py @@ -117,7 +117,7 @@ def flash_decode_stage1( batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] - + assert k.stride() == v.stride() _fwd_kernel_flash_decode_stage1[grid]( q, k, From a6a854075736f9658c0b267f55971105db13723f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 13:55:29 +0000 Subject: [PATCH 095/114] fix --- .../models/deepseek2/layer_infer/transformer_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 133336dfeb..2e7869a9d3 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -135,7 +135,7 @@ def _decompress_kv( b_req_idx=infer_state.b_req_idx, req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, b_seq_len=infer_state.b_seq_len, - b_kv_start_loc=infer_state.b1_kv_start_loc[:-1], + b_kv_start_loc=infer_state.b1_cu_kv_seq_len[:-1], max_kv_seq_len=infer_state.max_kv_seq_len, ) # CC From 7670c4db58e0482db9383cec536d5eb284fcee04 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 8 Jan 2026 14:07:26 +0000 Subject: [PATCH 096/114] fix --- lightllm/common/basemodel/basemodel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 6c64311db8..26d51af3db 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -350,6 +350,7 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID From 1d4884d3feba303946f285f82edb380da7a7d6a9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 01:54:47 +0000 Subject: [PATCH 097/114] fix --- lightllm/server/api_cli.py | 17 +++++++++++------ lightllm/utils/envs_utils.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 422a656a87..16c844d483 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -318,23 +318,28 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, nargs="+", choices=["None", "triton", "fa3", "flashinfer"], - default=["None"], - help="""prefill attention kernel used in llm""", + default=["triton"], + help="""prefill attention kernel used in llm. + None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( "--llm_decode_att_backend", type=str, nargs="+", choices=["None", "triton", "fa3", "flashinfer"], - default=["None"], - help="""decode attention kernel used in llm""", + default=["triton"], + help="""decode attention kernel used in llm. + None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( "--llm_kv_type", type=str, - choices=["None", "int8kv", "int4kv", "fp8kv"], + choices=["None", "int8kv", "int4kv"], default="None", - help="""kv type used in llm, None for dtype that llm used in config.json""", + help="""kv type used in llm, None for dtype that llm used in config.json. + fp8kv: not fully supported yet, will support in future""", ) parser.add_argument( "--llm_kv_quant_group_size", diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 0b54ef5dce..0a70f1dfa6 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -215,7 +215,7 @@ def get_diverse_max_batch_shared_group_size() -> int: @lru_cache(maxsize=None) def enable_diverse_mode_gqa_decode_fast_kernel() -> bool: - return get_env_start_args().diverse_mode and "ppl_int8kv_flashdecoding_diverse" in get_env_start_args().mode + return get_env_start_args().diverse_mode and "int8kv" == get_env_start_args().llm_kv_type @lru_cache(maxsize=None) From 289a3698da4749f67bd172e7ccfa2ca8ecb7c063 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 02:35:41 +0000 Subject: [PATCH 098/114] fix --- lightllm/common/basemodel/attention/triton/int4kv.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py index 6a7acaef66..25199dc470 100644 --- a/lightllm/common/basemodel/attention/triton/int4kv.py +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -73,8 +73,9 @@ def _groupsize_quant_prefill_att( assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() total_token_num = self.infer_state.total_token_num - k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) - v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + head_dim = k.shape[2] * 2 # 2个4bit存储为一个int8, 所以维度需要翻倍,才是解量化后的精度 + k_dequant = alloc_func((total_token_num, k.shape[1], head_dim), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], head_dim), dtype=q.dtype, device=q.device) o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) max_kv_seq_len = self.infer_state.max_kv_seq_len From 6aee5fe16d6c528dc02c2729639224043e7b89fa Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 03:11:32 +0000 Subject: [PATCH 099/114] fix. --- .../int4kv/int4kv_flash_decoding_stage1.py | 196 ++++++++++++++++++ .../int4kv/ppl_int4kv_flash_decoding.py | 30 +-- 2 files changed, 211 insertions(+), 15 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py new file mode 100644 index 0000000000..e19aaf4f12 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -0,0 +1,196 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def int4_to_float(k_int8, k_scale, offs_d): + k_high = ((k_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) + k_low = k_int8 & 0x0F + k_high = tl.where(k_high >= 8, k_high - 16, k_high) + k_low = tl.where(k_low >= 8, k_low - 16, k_low) + k_int4 = tl.where( + offs_d[None, :] % 2 == 0, + k_low, + k_high, + ) + k = k_int4.to(k_scale.dtype) * k_scale + return k + + +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + K_scale, + V, + V_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + quant_group_size, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + k_loc = k_loc.to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] // 2 + off_k_scale = off_k // (quant_group_size // 2) + k_int8 = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + k_scale = tl.load(K_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + k = int4_to_float(k_int8, k_scale, offs_d) + + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where((offs_n_new < cur_batch_end_index), att_value, float("-inf")) + v_int8 = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + v_scale = tl.load(V_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + v = int4_to_float(v_int8, v_scale, offs_d) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +@torch.no_grad() +def int4kv_flash_decode_stage1( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] * 2 + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, head_num = B_req_idx.shape[0], q.shape[1] + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + quant_group_size = Lk // k_scale.shape[-1] + assert triton.next_power_of_2(quant_group_size) == quant_group_size + assert k.stride() == v.stride() + _fwd_kernel_flash_decode_stage1[grid]( + q, + k, + k_scale, + v, + v_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + quant_group_size, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK_N, + num_warps=1, + num_stages=2, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py index 76184ea2e2..2df12c45b5 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -1,5 +1,4 @@ import torch -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops def token_decode_attention_flash_decoding( @@ -30,20 +29,21 @@ def token_decode_attention_flash_decoding( [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) - light_ops.group8_int4kv_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_kv_seq_len, + from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1 + + int4kv_flash_decode_stage1( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) From 2d9705fc1befc79eda92a636a4150bd5fe8734e7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 04:31:33 +0000 Subject: [PATCH 100/114] fix --- lightllm/models/deepseek2/triton_kernel/sample_kv.py | 1 - test/acc/test_deepseekr1.sh | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index ace118c9e5..88cef99b05 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -42,7 +42,6 @@ def _sample_kv_kernel( kv_loc = tl.load( req_to_token_indexs + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, - other=0, ).to(tl.int64) off_kv_nope = kv_loc[:, None] * stride_all_s + offs_nope_d[None, :] off_kv_rope = kv_loc[:, None] * stride_all_s + (offs_rope_d + BLOCK_NOPE_DIM)[None, :] diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh index 5fcfc0c08b..180d2d4e20 100644 --- a/test/acc/test_deepseekr1.sh +++ b/test/acc/test_deepseekr1.sh @@ -1,4 +1,4 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 +LOADWORKER=18 python -m lightllm.server.api_server --batch_max_tokens 6000 --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 From 4a4961f06799c9d50e59fcba9e78c5fb60952c90 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 04:42:36 +0000 Subject: [PATCH 101/114] fix --- .../models/deepseek2/layer_infer/transformer_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 2e7869a9d3..8695f2de89 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -105,7 +105,7 @@ def _token_attention_kernel( ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.get_att_input_params() + kv = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) out = infer_state.decode_att_state.decode_att( q=(q_nope, q_rope), From d9bcf6cee70492b9a905e005ec7232d149358ca7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 05:09:41 +0000 Subject: [PATCH 102/114] fix --- .../int4kv/int4kv_flash_decoding_stage1.py | 12 +++-- .../int4kv/ppl_int4kv_flash_decoding.py | 4 +- .../kv_copy/ppl_int4kv_copy_kv.py | 48 ++++++++++--------- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py index e19aaf4f12..212825a962 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -5,10 +5,13 @@ @triton.jit def int4_to_float(k_int8, k_scale, offs_d): - k_high = ((k_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 k_low = k_int8 & 0x0F - k_high = tl.where(k_high >= 8, k_high - 16, k_high) - k_low = tl.where(k_low >= 8, k_low - 16, k_low) + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 k_int4 = tl.where( offs_d[None, :] % 2 == 0, k_low, @@ -155,6 +158,7 @@ def int4kv_flash_decode_stage1( quant_group_size = Lk // k_scale.shape[-1] assert triton.next_power_of_2(quant_group_size) == quant_group_size assert k.stride() == v.stride() + # TODO 优化为gqa使用tensor core的实现,速度更快。 _fwd_kernel_flash_decode_stage1[grid]( q, k, @@ -190,7 +194,7 @@ def int4kv_flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, - num_warps=1, + num_warps=4, num_stages=2, ) return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py index 2df12c45b5..a5a054b93a 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -23,10 +23,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" ) from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1 diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py index f7268a853d..53d1256ec7 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -60,17 +60,19 @@ def _fwd_kernel_destindex_copy_quantize_int4_kv( q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) + q_src_data_0 += 7 q_src_data_0 = q_src_data_0.to(tl.uint8, bitcast=True) q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) + q_src_data_1 += 7 q_src_data_1 = q_src_data_1.to(tl.uint8, bitcast=True) - low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) - high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 + low_4 = q_src_data_0 & 0xF + high_4 = (q_src_data_1 & 0xF) << 4 - out_data = (low_4 | high_4).to(tl.int8, bitcast=True) + out_data = (low_4 | high_4).to(Out.dtype.element_ty, bitcast=True) o_ptrs = ( Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] @@ -136,6 +138,24 @@ def destindex_copy_int4kv( return +@triton.jit +def int4_to_float(k_int8, offs_d): + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 + + k_int4 = tl.where( + offs_d[None, None, :] % 2 == 0, + k_low, + k_high, + ) + return k_int4 + + @triton.jit def _fwd_dequantize_int4kv( k, @@ -206,16 +226,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * k_sg + offs_d[None, None, :] // 2 ) - k_high = ((k_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) - k_low = k_int8 & 0x0F - k_high = tl.where(k_high >= 8, k_high - 16, k_high) - k_low = tl.where(k_low >= 8, k_low - 16, k_low) - - k_int4 = tl.where( - offs_d[None, None, :] % 2 == 0, - k_low, - k_high, - ) + k_int4 = int4_to_float(k_int8, offs_d) k_scale_data = tl.load( k_scale @@ -242,16 +253,7 @@ def _fwd_dequantize_int4kv( + group_offs[None, :, None] * v_sg + offs_d[None, None, :] // 2 ) - v_high = ((v_int8.to(tl.uint8, bitcast=True) & 0xF0) >> 4).to(tl.int8, bitcast=True) - v_low = v_int8 & 0x0F - v_high = tl.where(v_high >= 8, v_high - 16, v_high) - v_low = tl.where(v_low >= 8, v_low - 16, v_low) - - v_int4 = tl.where( - offs_d[None, None, :] % 2 == 0, - v_low, - v_high, - ) + v_int4 = int4_to_float(v_int8, offs_d) v_scale_data = tl.load( v_scale + kv_loc[:, None, None] * v_scale_ss From 5db566a5054a2a1ccfab1e4c41a7eda988539656 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 05:12:23 +0000 Subject: [PATCH 103/114] fix --- lightllm/common/basemodel/attention/base_att.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 42f33a452a..859d97ca84 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -81,8 +81,8 @@ def init_state(self): def prefill_att( self, q: torch.Tensor, - k: torch.tensor, - v: torch.tensor, + k: torch.Tensor, + v: torch.Tensor, att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: From c89bcd11c10699521b2f36020b7c9a138be525ef Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 05:26:04 +0000 Subject: [PATCH 104/114] fix --- lightllm/models/deepseek2/triton_kernel/sample_kv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index 88cef99b05..c09d313e22 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -73,7 +73,7 @@ def sample_kv( batch = b_seq_len.shape[0] BLOCK = 64 if not is_tesla() else 32 - num_warps = 4 + num_warps = 8 grid = ( batch, triton.cdiv(max_kv_seq_len, BLOCK), From 730fd5048ca6e2aae94295bc2f82ce0795ace3ae Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 05:26:48 +0000 Subject: [PATCH 105/114] fix --- lightllm/common/basemodel/attention/triton/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index 5689e4979a..80ce05bb61 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -105,7 +105,7 @@ def _mla_decode_att( alloc_func=torch.empty, ): assert att_control.mla_decode - softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + softmax_scale = att_control.mla_decode_dict["softmax_scale"] from ...triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding From b503525b6940afee6402a9a94f46cdb791dcfcba Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 05:34:32 +0000 Subject: [PATCH 106/114] fix --- lightllm/common/basemodel/attention/triton/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index 80ce05bb61..8288193ad7 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -116,7 +116,7 @@ def _mla_decode_att( out = gqa_token_decode_attention_flash_decoding( q_nope=q_nope, q_rope=q_rope, - kv_nope=kv[:, :, :qk_rope_head_dim], + kv_nope=kv[:, :, :-qk_rope_head_dim], kv_rope=kv[:, :, -qk_rope_head_dim:], infer_state=self.infer_state, softmax_scale=softmax_scale, From a1b85a780d44006e8761e9a6854908e7f1b386e5 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 08:42:32 +0000 Subject: [PATCH 107/114] fix --- .../triton_kernel/kv_copy/mla_copy_kv.py | 11 ++- .../deepseek2_mem_manager.py | 1 + .../triton_kernel/repack_kv_index.py | 90 ------------------- .../deepseek2/triton_kernel/sample_kv.py | 2 +- 4 files changed, 9 insertions(+), 95 deletions(-) delete mode 100644 lightllm/models/deepseek2/triton_kernel/repack_kv_index.py diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py index 39deb1b6f7..41a25877a7 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py @@ -36,11 +36,11 @@ def _fwd_kernel_destindex_copy_kv( dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] - kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] + kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope - o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] - o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] + o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope kv_nope = tl.load(kv_nope_ptrs) kv_rope = tl.load(kv_rope_ptrs) @@ -60,6 +60,9 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): assert KV_nope.shape[2] == O_nope.shape[2] assert KV_rope.shape[1] == O_rope.shape[1] assert KV_rope.shape[2] == O_rope.shape[2] + assert triton.next_power_of_2(kv_nope_head_dim) == kv_nope_head_dim + assert triton.next_power_of_2(kv_rope_head_dim) == kv_rope_head_dim + grid = (seq_len,) num_warps = 1 diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3629acf973..3d93e1b070 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -26,6 +26,7 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: rope_dim = 64 kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] destindex_copy_kv( kv[:, :, :kv_lora_rank], diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py deleted file mode 100644 index e86d2e819e..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_repack_kv_index( - kv_index, - req_index, - out_kv_index, - seq_len, - start_loc, - kv_stride_h, - SEQ_BLOCK: tl.constexpr, -): - cur_batch = tl.program_id(0) - start_seq_n = tl.program_id(1) - - cur_batch_seq_len = tl.load(seq_len + cur_batch) - cur_batch_req_idx = tl.load(req_index + cur_batch) - cur_batch_start_loc = tl.load(start_loc + cur_batch) - - offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) - block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) - kv_index_data = tl.load( - kv_index + kv_stride_h * cur_batch_req_idx + offs_seq, - mask=offs_seq < block_end_loc, - other=0, - ) - out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq - tl.store(out_kv_index_ptr, kv_index_data, mask=offs_seq < block_end_loc) - return - - -@torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): - batch_size = req_index.shape[0] - # flashinfer requires out_kv_index to be zeroed before use - out_kv_index.zero_() - BLOCK = 64 - grid = ( - batch_size, - triton.cdiv(max_seq_len, BLOCK), - ) - - _fwd_kernel_repack_kv_index[grid]( - kv_index, - req_index, - out_kv_index, - seq_len, - start_loc, - kv_index.stride(0), - SEQ_BLOCK=BLOCK, - num_warps=8, - num_stages=1, - ) - return - - -def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - output[start : start + sl] = req_to_token_indexs[b][:sl] - - -if __name__ == "__main__": - import torch.nn.functional as F - - BATCH, MAX_SEQ_LEN = 10, 1024 - rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() - b_req_idx = torch.randperm(BATCH).cuda().int() - b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() - req_to_token_indexs = torch.zeros((2 * BATCH, 2 * MAX_SEQ_LEN)).cuda().int() - b_start_loc = ( - torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[0:-1].cumsum(0)]) - .cuda() - .int() - ) - - output = torch.zeros((b_seq_len.sum(),)).cuda().int() - ref = torch.zeros((b_seq_len.sum(),)).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] - - fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) - fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) - ms1 = triton.testing.do_bench(fn1) - ms2 = triton.testing.do_bench_cudagraph(fn2) - print(ms1, ms2) - assert torch.allclose(output.float(), ref.float()) diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index c09d313e22..0cc3c66abe 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -69,7 +69,7 @@ def sample_kv( ): nope_dim = sampled_compressed_kv_nope.shape[-1] rope_dim = sampled_k_rope.shape[-1] - + assert rope_dim == 64 batch = b_seq_len.shape[0] BLOCK = 64 if not is_tesla() else 32 From fab336543837f21ff1b2419c0e72a3984065538f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 09:16:37 +0000 Subject: [PATCH 108/114] fix --- lightllm/models/deepseek2/triton_kernel/sample_kv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index 0cc3c66abe..53a0a60eb2 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -37,7 +37,7 @@ def _sample_kv_kernel( offs_rope_d = tl.arange(0, BLOCK_ROPE_DIM) offs_m = (start_m * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)) % cur_batch_seq_len - if (start_m + 1) * BLOCK_SEQ > cur_batch_seq_len: + if start_m * BLOCK_SEQ > cur_batch_seq_len: return kv_loc = tl.load( From 3c400b860cc06eaf1eca7b1683ee50db5da140e3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 09:23:11 +0000 Subject: [PATCH 109/114] fix --- .../triton_kernel/destindex_copy_kv.py | 107 ------------------ 1 file changed, 107 deletions(-) delete mode 100644 lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py deleted file mode 100644 index 39deb1b6f7..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -def _is_power_of_two(n): - return n > 0 and (n & (n - 1)) == 0 - - -@triton.jit -def _fwd_kernel_destindex_copy_kv( - KV_nope, - KV_rope, - Dest_loc, - O_nope, - O_rope, - stride_kv_nope_bs, - stride_kv_nope_h, - stride_kv_nope_d, - stride_kv_rope_bs, - stride_kv_rope_h, - stride_kv_rope_d, - stride_o_nope_bs, - stride_o_nope_h, - stride_o_nope_d, - stride_o_rope_bs, - stride_o_rope_h, - stride_o_rope_d, - BLOCK_DMODEL_NOPE: tl.constexpr, - BLOCK_DMODEL_ROPE: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) - offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] - kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] - - o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] - o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] - - kv_nope = tl.load(kv_nope_ptrs) - kv_rope = tl.load(kv_rope_ptrs) - - tl.store(o_nope_ptrs, kv_nope) - tl.store(o_rope_ptrs, kv_rope) - return - - -@torch.no_grad() -def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): - seq_len = DestLoc.shape[0] - kv_nope_head_dim = KV_nope.shape[2] - kv_rope_head_dim = KV_rope.shape[2] - - assert KV_nope.shape[1] == O_nope.shape[1] - assert KV_nope.shape[2] == O_nope.shape[2] - assert KV_rope.shape[1] == O_rope.shape[1] - assert KV_rope.shape[2] == O_rope.shape[2] - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_kv[grid]( - KV_nope, - KV_rope, - DestLoc, - O_nope, - O_rope, - KV_nope.stride(0), - KV_nope.stride(1), - KV_nope.stride(2), - KV_rope.stride(0), - KV_rope.stride(1), - KV_rope.stride(2), - O_nope.stride(0), - O_nope.stride(1), - O_nope.stride(2), - O_rope.stride(0), - O_rope.stride(1), - O_rope.stride(2), - BLOCK_DMODEL_NOPE=kv_nope_head_dim, - BLOCK_DMODEL_ROPE=kv_rope_head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -if __name__ == "__main__": - import torch.nn.functional as F - - B, N_CTX, H, NOPE_HEAD, ROPE_HEAD = 32, 1024, 1, 512, 64 - dtype = torch.bfloat16 - dest_loc = torch.randint(0, 100, (50,), device="cuda").unique() - kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() - O_nope = torch.zeros((B * N_CTX, H, NOPE_HEAD), dtype=dtype).cuda() - O_rope = torch.zeros((B * N_CTX, H, ROPE_HEAD), dtype=dtype).cuda() - - kv_nope = kv[:, :, :NOPE_HEAD] - kv_rope = kv[:, :, NOPE_HEAD:] - destindex_copy_kv(kv_nope, kv_rope, dest_loc, O_nope, O_rope) - - assert torch.allclose(O_nope[dest_loc], kv_nope, atol=1e-2, rtol=0) - assert torch.allclose(O_rope[dest_loc], kv_rope, atol=1e-2, rtol=0) From 2528c4c7cc9943195b810e9ba9969287f0ea6e47 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 10:19:52 +0000 Subject: [PATCH 110/114] fix --- lightllm/common/basemodel/attention/flashinfer/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index bed52db94c..0e060ee44a 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -99,7 +99,7 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[1]), q.dtype, device="cuda") q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) self.prefill_wrapper.run(q, k, v, out=o_tensor) From b50d74cd355f29b28627c311cb0480d369ba91c3 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 10:24:50 +0000 Subject: [PATCH 111/114] fix --- lightllm/common/basemodel/attention/flashinfer/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 0e060ee44a..6e52203b4f 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -99,7 +99,7 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[1]), q.dtype, device="cuda") + o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda") q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) self.prefill_wrapper.run(q, k, v, out=o_tensor) From c17eed19b66160ab8d4cf6b014257d15e83d6c84 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 13:15:29 +0000 Subject: [PATCH 112/114] fix --- lightllm/common/basemodel/cuda_graph.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 6c5f835784..56ee346bda 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -77,10 +77,16 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): # 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性, # 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象 # 中的 tensor。 + for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state)) torch.cuda.synchronize() + for param_name in vars(infer_state).keys(): + if param_name not in pure_para_set: + delattr(infer_state, param_name) with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): @@ -106,9 +112,19 @@ def _capture_decode_overlap( infer_state1.total_token_num = self.graph_max_len_in_batch * batch_size # warmup for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) + pure_para_set1 = set(vars(infer_state1).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() + for para_name in vars(infer_state).keys(): + if para_name not in pure_para_set: + delattr(infer_state, para_name) + for para_name in vars(infer_state1).keys(): + if para_name not in pure_para_set1: + delattr(infer_state1, para_name) + with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): From 74b8f5b36d0f33c540c1b32f6c363f0f693d0381 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 13:22:11 +0000 Subject: [PATCH 113/114] fix --- lightllm/common/basemodel/cuda_graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 56ee346bda..dd29c9a833 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -84,7 +84,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): torch.cuda.synchronize() decode_func(copy.copy(infer_state)) torch.cuda.synchronize() - for param_name in vars(infer_state).keys(): + for param_name in set(vars(infer_state).keys()): if param_name not in pure_para_set: delattr(infer_state, param_name) @@ -118,10 +118,10 @@ def _capture_decode_overlap( torch.cuda.synchronize() decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() - for para_name in vars(infer_state).keys(): + for para_name in set(vars(infer_state).keys()): if para_name not in pure_para_set: delattr(infer_state, para_name) - for para_name in vars(infer_state1).keys(): + for para_name in set(vars(infer_state1).keys()): if para_name not in pure_para_set1: delattr(infer_state1, para_name) From f4b982cad893db1270235e2e64ca60d4041df117 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 9 Jan 2026 14:57:29 +0000 Subject: [PATCH 114/114] fix --- lightllm/common/basemodel/attention/fa3/mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py index 95a4120ff5..9a10457b12 100644 --- a/lightllm/common/basemodel/attention/fa3/mla.py +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -217,7 +217,7 @@ def _mla_decode_att( cache_seqlens=self.b_att_seq_len, cu_seqlens_q=self.cu_seqlens_q, cu_seqlens_k_new=self.cu_seqlens_k, - max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_q=self.decode_max_q_seq_len, softmax_scale=softmax_scale, causal=True, window_size=(-1, -1),