diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py similarity index 91% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py index 67463fa7ad..ac18ffb955 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py @@ -1,4 +1,7 @@ import pytest + +pytest.skip(reason="need install lightllmKernel", allow_module_level=True) + import torch from lightllm.utils.light_utils import light_ops @@ -21,7 +24,7 @@ class MockInferState: def __init__( self, batch_size, - max_len_in_batch, + max_kv_seq_len, req_to_tokens, b_req_idx, b_seq_len, @@ -29,7 +32,7 @@ def __init__( b_mark_shared_group=None, ): self.batch_size = batch_size - self.max_len_in_batch = max_len_in_batch + self.max_kv_seq_len = max_kv_seq_len self.req_manager = MockReqManager(req_to_tokens) self.b_req_idx = b_req_idx self.b_seq_len = b_seq_len @@ -44,10 +47,11 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le 测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding 与 ppl_int8kv_flash_decoding (baseline) 的对比。 """ - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( + + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding as baseline_attention, ) @@ -87,7 +91,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 baseline 的 infer_state (不需要 b_shared_seq_len) baseline_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=seq_len, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -96,7 +100,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 diverse 的 infer_state diverse_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=seq_len, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -108,8 +112,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le baseline_out = baseline_attention( q=q.clone(), infer_state=baseline_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, @@ -120,8 +122,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le diverse_out = diverse_attention( q=q.clone(), infer_state=diverse_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py similarity index 93% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py index 30e83b88b6..5ef36e38e2 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py @@ -1,6 +1,8 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage1 import ( + flash_decode_stage1, +) @pytest.fixture @@ -81,7 +83,7 @@ def test_flash_decode_stage1_execution(setup_tensors): new_k = k.to(q.dtype) new_v = v.to(q.dtype) - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( flash_decode_stage1 as gqa_flash_decode_stage1, ) diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py similarity index 96% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py index 2ba085cc91..cde7734817 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py @@ -1,4 +1,7 @@ import pytest + +pytest.skip(reason="need install lightllmkernel", allow_module_level=True) + import torch from lightllm.utils.light_utils import light_ops @@ -94,7 +97,7 @@ def test_flash_decode_stage2_execution(shared_seq_len): b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( flash_decode_stage1 as gqa_flash_decode_stage1, ) diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py similarity index 79% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py index b406e2dcf5..18550982b9 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py @@ -1,6 +1,8 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage3 import ( + flash_diverse_decode_stage3, +) @pytest.mark.parametrize( @@ -23,7 +25,10 @@ def test_flash_diverse_decode_stage3(batch, head_num, seq_len, shared_seq_len, b flash_diverse_decode_stage3(mid_out, mid_out_logexpsum, B_Seqlen, b_shared_seq_len, out, block_seq) true_out = torch.zeros_like(out) - from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 + + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding_stage2 import ( + flash_decode_stage2, + ) flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, true_out, block_seq) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py similarity index 92% rename from unit_tests/models/llama/test_context_flashattention_nopad.py rename to unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py index f24ab619bd..541594306d 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py @@ -5,12 +5,11 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( +from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( context_attention_fwd, context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -54,14 +53,14 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc + infer_state.b_q_start_loc = q_start_loc context_attention_fwd( q, @@ -69,10 +68,10 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): kv[:, KV_HEADS:, :], o, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, ) @@ -127,7 +126,11 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): "batch, seqlen, q_heads, kv_heads, head_dim", [ (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] + for a in [ + 1, + 16, + 32, + ] for b in [16, 32, 512, 1024] for c in [28] for d in [4] @@ -149,18 +152,18 @@ def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc + infer_state.b_q_start_loc = b_start_loc context_attention_fwd_no_prompt_cache( q, k, v, o, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, ) head_dim = HEAD_DIM diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py similarity index 93% rename from unit_tests/models/deepseek2/test_destindex_copy_kv.py rename to unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py index 1379dc72de..ed0c6e369f 100644 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py @@ -1,6 +1,6 @@ import torch import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv +from lightllm.common.basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv from lightllm.utils.log_utils import init_logger import torch.nn.functional as F diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py similarity index 92% rename from unit_tests/models/deepseek2/test_gqa_flash_decoding.py rename to unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py index d0bc670ecb..a5ac9708d2 100644 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py +++ b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py @@ -5,9 +5,10 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.common.basemodel.triton_kernel.mla_att.decode_att.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, +) from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -53,7 +54,7 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len @@ -67,10 +68,6 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): kv_nope, kv_rope, infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, sm_scale, o, ) diff --git a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py index 536cad90fc..0afcd5558a 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py +++ b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py @@ -18,10 +18,10 @@ def test_add_in_place(): assert input.item() == 3, "最终值应为 3" -@pytest.mark.timeout(2) -def test_wait_timeout(): - input = torch.zeros((1,), device="cuda", dtype=torch.int32) - wait_value(input, 4) +# @pytest.mark.timeout(2) +# def test_wait_timeout(): +# input = torch.zeros((1,), device="cuda", dtype=torch.int32) +# wait_value(input, 4) if __name__ == "__main__": diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py index 99971dea25..e9d0193279 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py @@ -25,6 +25,7 @@ def test_token_id_counter(): for _ in range(100): token_id_counter(prompt_ids=test_prompt_ids, out_token_id_counter=test_token_id_counter) end_event.record() + end_event.synchronize() logger.info(f"test_token_id_count cost time: {start_event.elapsed_time(end_event)} ms") diff --git a/unit_tests/models/deepseek2/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py similarity index 96% rename from unit_tests/models/deepseek2/test_repack_kv_index.py rename to unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py index f9e5928a9e..b5184d3caa 100644 --- a/unit_tests/models/deepseek2/test_repack_kv_index.py +++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py @@ -1,7 +1,7 @@ import torch import pytest from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index logger = init_logger(__name__) diff --git a/unit_tests/common/fused_moe/test_deepep.py b/unit_tests/common/fused_moe/test_deepep.py index c846be0961..45778244b7 100644 --- a/unit_tests/common/fused_moe/test_deepep.py +++ b/unit_tests/common/fused_moe/test_deepep.py @@ -1,12 +1,13 @@ +import pytest + +pytest.skip(reason="need special env, install deep_ep and deep_gemm", allow_module_level=True) + import os import torch import torch.distributed as dist -import pytest import deep_ep import random import numpy as np -from deep_ep import Buffer, EventOverlap -from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from typing import Tuple @@ -25,6 +26,8 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape + from deep_gemm import ceil_div + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index eba15b2a1c..671805a3d2 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -1,6 +1,18 @@ import torch -import time import pytest + + +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True) + import random from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd diff --git a/unit_tests/common/fused_moe/test_softmax_topk.py b/unit_tests/common/fused_moe/test_softmax_topk.py index 262c37a0f2..6252dfa8c3 100755 --- a/unit_tests/common/fused_moe/test_softmax_topk.py +++ b/unit_tests/common/fused_moe/test_softmax_topk.py @@ -9,7 +9,10 @@ def benchmark(M, N, K, renorm, runs): - import sgl_kernel as sgl_ops + try: + import sgl_kernel as sgl_ops + except Exception as e: + pytest.skip(f"no sgl_kernel error: {str(e)}", allow_module_level=True) gating = torch.randn(M, N, device="cuda", dtype=torch.float32) torch.cuda.synchronize() diff --git a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py index 1ddb20b632..2c0b7bf76e 100644 --- a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py +++ b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py @@ -4,6 +4,18 @@ from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip("not support fp8 in this gpu card", allow_module_level=True) + + @pytest.mark.parametrize("M", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("N,K", [(2048, 2048), (4096, 5120), (8192, 4096)]) @pytest.mark.parametrize("output_dtype", [torch.bfloat16]) diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py b/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py deleted file mode 100644 index 4f9c0a3373..0000000000 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.utils.log_utils import init_logger -import torch.nn.functional as F - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head, copy_len", - [ - (a, b, c, d, e, f) - for a in [1, 16, 32, 128, 512] - for b in [1024, 2048] - for c in [1] - for d in [512] - for e in [64] - for f in [10, 20, 100, 1024] - ], -) -def test_destindex_copy_kv_fp8(batch, seqlen, heads, nope_head, rope_head, copy_len): - B, N_CTX, H, NOPE_HEAD, ROPE_HEAD, COPY_LEN = batch, seqlen, heads, nope_head, rope_head, copy_len - dtype = torch.bfloat16 - NUM = COPY_LEN - dest_loc = torch.arange(NUM).cuda() - kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() - out = torch.zeros((B * N_CTX, H, NOPE_HEAD + ROPE_HEAD + 2), dtype=torch.uint8).cuda() - - fp8_type = torch.float8_e4m3fn - kv_nope = kv[:, :, :NOPE_HEAD] - kv_rope = kv[:, :, NOPE_HEAD:] - O_nope = out[:, :, :NOPE_HEAD].view(fp8_type) - O_rope = out[:, :, NOPE_HEAD:-2].view(fp8_type) - O_scale = out[:, :, -2:].view(dtype) - destindex_copy_kv_fp8(kv_nope, kv_rope, dest_loc, O_nope, O_rope, O_scale) - - cos1 = F.cosine_similarity(O_nope[:NUM].to(dtype) * O_scale[:NUM], kv_nope).mean() - cos2 = F.cosine_similarity(O_rope[:NUM].to(dtype) * O_scale[:NUM], kv_rope).mean() - assert cos1 > 0.98 - assert cos2 > 0.98 diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py b/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py deleted file mode 100644 index 72d9d9accc..0000000000 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head", - [(a, b, c, d, e) for a in [1, 16, 32, 128] for b in [16, 32, 512, 2048] for c in [16] for d in [512] for e in [64]], -) -def test_gqa_flash_decoding_fp8(batch, seqlen, heads, nope_head, rope_head): - Z, N_CTX, H, D_HEAD, ROPE_HEAD = batch, seqlen, heads, nope_head, rope_head - dtype = torch.bfloat16 - sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5) - q = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda") - q_rope = torch.randn((Z, H, ROPE_HEAD), dtype=dtype, device="cuda") - - kv = torch.randn((Z * N_CTX, 1, D_HEAD + ROPE_HEAD), dtype=dtype, device="cuda") - kv_scale = torch.randn((Z * N_CTX, 1, 1), dtype=dtype, device="cuda") - kv_fp8 = kv.to(torch.float8_e4m3fn) - - req_to_token_indexs = torch.zeros((10, Z * N_CTX), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - - b_seq_len[0] = N_CTX - b_req_idx[0] = 0 - req_to_token_indexs[0][:N_CTX] = torch.tensor(np.arange(N_CTX), dtype=torch.int32).cuda() - - o = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - o1 = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - - infer_state = Deepseek2InferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - - kv_nope = kv_fp8[:, :, :D_HEAD].to(dtype) * kv_scale - kv_rope = kv_fp8[:, :, D_HEAD:].to(dtype) * kv_scale - gqa_token_decode_attention_flash_decoding( - q, - q_rope, - kv_nope, - kv_rope, - infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, - sm_scale, - o, - ) - - kv_nope_fp8 = kv_fp8[:, :, :D_HEAD] - kv_rope_fp8 = kv_fp8[:, :, D_HEAD:] - gqa_token_decode_attention_flash_decoding_fp8( - q, q_rope, kv_nope_fp8, kv_rope_fp8, kv_scale, infer_state, H, D_HEAD, ROPE_HEAD, D_HEAD, sm_scale, o1 - ) - - cos_sim = F.cosine_similarity(o, o1).mean() - assert cos_sim > 0.99 diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py deleted file mode 100644 index 737bb655b1..0000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import time -import pytest -import triton as tl -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - - k_cache = kv[:, :KV_HEADS, :] - v_cache = kv[:, KV_HEADS:, :] - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), - # v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), - # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=N_CTX, - # causal=True, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = q_per_head_fp8_quant(q.view(q.shape[0], kv_heads, -1), q_lens, q_starts) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - # page_table=page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=N_CTX, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1.item() == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py deleted file mode 100644 index 5ee2306adf..0000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 64 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - kv_start_loc = b_seq_len.cumsum(0) - b_seq_len - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_indptr = q_starts.int() - kv_indptr = kv_starts.int() - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, kv_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, - qo_indptr_buf=q_indptr, - paged_kv_indptr_buf=kv_indptr, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len_buffer, - ) - kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32) - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) - v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - q_heads, - kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=q.dtype, - kv_data_type=torch.float8_e4m3fn, - ) - wrapper.run( - q, - (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), - k_scale=k_scale, - v_scale=v_scale, - out=o1, - return_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_flashinfer_fp8(16, 1024, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad.py b/unit_tests/models/llama/test_token_attention_nopad.py deleted file mode 100644 index 1bbb291662..0000000000 --- a/unit_tests/models/llama/test_token_attention_nopad.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX - b_start_loc = torch.arange(Z).cuda().int() * N_CTX - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # infer_state.req_manager.req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_indptr = kv_starts - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=kv_indptr, - paged_kv_indices_buffer=kv_indices, - paged_kv_last_page_len_buffer=kv_last_page_len_buffer, - ) - kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32) - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_len_buffer, - q_heads, - kv_heads, - head_dim, - page_size, - q_data_type=dtype, - non_blocking=True, - ) - kv = kv.unsqueeze(1) - wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) - - cos_sim1 = F.cosine_similarity(o, o1).mean() - assert cos_sim1 == 1.0 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py deleted file mode 100644 index a7f48ab899..0000000000 --- a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_starts = torch.arange(0, Z + 1).int().cuda() - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # v_cache=v_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=1, - # causal=False, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = scaled_fp8_quant(q.view(batch_size * kv_heads, -1), use_per_token_if_dynamic=True) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim), - # page_table=page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py deleted file mode 100644 index 5c0e595b96..0000000000 --- a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_indptr = kv_starts - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=kv_indptr, - paged_kv_indices_buffer=kv_indices, - paged_kv_last_page_len_buffer=kv_last_page_len_buffer, - ) - kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32) - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) - v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_len_buffer, - q_heads, - kv_heads, - head_dim, - page_size, - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn, - non_blocking=True, - ) - wrapper.run( - q, - (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), - k_scale=k_scale, - v_scale=v_scale, - out=o1, - return_lse=False, - ) - - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1.0 - - -if __name__ == "__main__": - test_token_attention_nopad_flashinfer_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/qwen3-vl/test_deepstack_emb.py b/unit_tests/models/qwen3-vl/test_deepstack_emb.py index 2f929fe0d0..f629a16352 100644 --- a/unit_tests/models/qwen3-vl/test_deepstack_emb.py +++ b/unit_tests/models/qwen3-vl/test_deepstack_emb.py @@ -50,7 +50,7 @@ def test_deepstack_same_image_twice(): deepstack_embs=deepstack_embs, img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, - img_start_locs=img_start_locs, + img_start_locs_in_cache=img_start_locs, ) # 7. 看看相同图片两段上的增量 diff --git a/unit_tests/server/core/objs/test_req.py b/unit_tests/server/core/objs/test_req.py index 45fa7967fb..1c946531c1 100644 --- a/unit_tests/server/core/objs/test_req.py +++ b/unit_tests/server/core/objs/test_req.py @@ -1,6 +1,22 @@ import pytest - +import easydict from lightllm.server.core.objs.req import Req, TokenHealingReq, ChunkedPrefillReq, SamplingParams +from lightllm.utils.envs_utils import set_env_start_args + + +@pytest.fixture(scope="module", autouse=True) +def setup_module_env(): + set_env_start_args( + easydict.EasyDict( + { + "mtp_step": 0, + "llm_prefill_att_backend": ["None"], + "llm_decode_att_backend": ["None"], + "cpu_cache_token_page_size": 256, + "enable_cpu_cache": False, + } + ) + ) @pytest.fixture diff --git a/unit_tests/server/core/objs/test_shm_req_manager.py b/unit_tests/server/core/objs/test_shm_req_manager.py index dea40a4859..e26f128d5b 100644 --- a/unit_tests/server/core/objs/test_shm_req_manager.py +++ b/unit_tests/server/core/objs/test_shm_req_manager.py @@ -14,6 +14,11 @@ def setup_env(): running_max_req_size=10, disable_chunked_prefill=True, token_healing_mode=False, + mtp_step=0, + llm_prefill_att_backend=["None"], + llm_decode_att_backend=["None"], + cpu_cache_token_page_size=256, + enable_cpu_cache=False, ) ) # clear the lru_cache if used