From 62a9f0cbde99413a437595623619dbb5d04a5509 Mon Sep 17 00:00:00 2001 From: STwangyingrui Date: Mon, 20 Apr 2026 11:20:51 +0000 Subject: [PATCH] support gqa --- spas_sage_attn/utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/spas_sage_attn/utils.py b/spas_sage_attn/utils.py index 9b5761c..8a6255c 100644 --- a/spas_sage_attn/utils.py +++ b/spas_sage_attn/utils.py @@ -331,6 +331,13 @@ def get_block_map_meansim(q, k, is_causal=False, BLKQ=128, BLKK=64, simthreshd1= pooled_qblocks, sim_qblocks = get_pool_sim_triton_simmean(q, BLKQ, simthreshd1) pooled_kblocks, sim_kblocks = get_pool_sim_triton_simmean(k, BLKK, simthreshd1) + # GQA + num_kv_heads = k.size(1) + if num_kv_heads != Headnum: + repeat_factor = Headnum // num_kv_heads + pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1) + sim_kblocks = sim_kblocks.repeat_interleave(repeat_factor, dim=1) + sim_kblocks = sim_kblocks.unsqueeze(-2).expand(-1, -1, nq, -1) # faster than repeat sim_qblocks = sim_qblocks.unsqueeze(-1).expand(-1, -1, -1, nk) pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) * q.shape[-1] ** -0.5 @@ -383,6 +390,13 @@ def get_block_map_meansim_fuse_quant(q, k, km=None, is_causal=False, BLKQ=128, B pooled_qblocks, sim_qblocks, q_int8, q_scale = get_pool_sim_triton_simmean_fuse_quant(q, None, BLKQ, simthreshd1) pooled_kblocks, sim_kblocks, k_int8, k_scale = get_pool_sim_triton_simmean_fuse_quant(k, km, BLKK, simthreshd1) + # GQA + num_kv_heads = k.size(1) + if num_kv_heads != Headnum: + repeat_factor = Headnum // num_kv_heads + pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1) + sim_kblocks = sim_kblocks.repeat_interleave(repeat_factor, dim=1) + sim_kblocks = sim_kblocks.unsqueeze(-2).expand(-1, -1, nq, -1) # faster than repeat sim_qblocks = sim_qblocks.unsqueeze(-1).expand(-1, -1, -1, nk) pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) * q.shape[-1] ** -0.5 @@ -444,4 +458,4 @@ def block_map_lut(block_map): lut = torch.sort(filled_matrix, dim=-1)[0] - 1 # make index start from 0 lut[:, :, :, 1:] = lut[:, :, :, 1:] - lut[:, :, :, :-1] - return lut.to(torch.int32), valid_entry_num.to(torch.int32) \ No newline at end of file + return lut.to(torch.int32), valid_entry_num.to(torch.int32)