
Commit 7569674

Made changes to the runtime to support the normal (non-absorbed) MLA kernel
1 parent e3ac7b5 commit 7569674

3 files changed

Lines changed: 135 additions & 2 deletions


python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 44 additions & 1 deletion
```diff
@@ -180,6 +180,49 @@ def mla_absorbed(
             )
         ).reshape(b, s, h_qo, kv_lora_rank)
 
+    def mla_normal(
+        self,
+        layer_id: int,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        compressed_kv: Tensor,
+        k_pe: Tensor,
+        attn_score_scaling_factor: float = 1.0,
+    ) -> Tensor:
+        """Compute multi-head latent attention with the given data
+        on the specified layer using the normal flow (WITHOUT weight absorption).
+        """
+        # pylint: disable=protected-access
+        b, s, h_qo, d_qk = q._expr.struct_info.shape
+        d_v = v._expr.struct_info.shape[3]
+        kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
+        qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
+        q = q.reshape(b * s, h_qo, d_qk)
+        k = k.reshape(b * s, h_qo, d_qk)
+        v = v.reshape(b * s, h_qo, d_v)
+        compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
+        k_pe = k_pe.reshape(b * s, qk_rope_head_dim)
+
+        return Tensor(
+            _expr=rx.BlockBuilder.current().emit(
+                rx.call_dps_packed(
+                    "vm.builtin.attention_kv_cache_mla_normal",
+                    [
+                        self._expr,
+                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
+                        rx.PrimValue(attn_score_scaling_factor),
+                        q._expr,
+                        k._expr,
+                        v._expr,
+                        compressed_kv._expr,
+                        k_pe._expr,
+                    ],
+                    out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
+                )
+            )
+        ).reshape(b, s, h_qo, d_v)
+
     def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
         """Get the in-sequence positions of each slot in the query,
         which are needed for applying positional embeddings in some models.
@@ -591,7 +634,7 @@ def create_mla_kv_cache( # pylint: disable=too-many-locals
         rx.PrimValue(0),
         bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"),
         bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"),
-        bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
+        bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
         bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"),
         bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"),
         bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"),
```
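For context, here is a minimal sketch of how a model's attention block might call the new `mla_normal` method. Everything besides the `mla_normal` call itself — the projection layers (`q_proj`, `kv_a_proj`, `k_pe_proj`, ...), member names, and dimensions — is an illustrative assumption, not part of this commit:

```python
from tvm.relax.frontend.nn import Tensor

# Hypothetical forward pass of an MLA attention block (non-absorbed path).
def forward(self, hidden_states: Tensor, layer_id: int) -> Tensor:
    b, s, _ = hidden_states.shape
    # Full-rank q/k/v used for the attention computation itself.
    q = self.q_proj(hidden_states).reshape(b, s, self.num_heads, self.qk_head_dim)
    k = self.k_proj(hidden_states).reshape(b, s, self.num_heads, self.qk_head_dim)
    v = self.v_proj(hidden_states).reshape(b, s, self.num_heads, self.v_head_dim)
    # Low-rank latent kv plus the decoupled RoPE key; these are what the
    # paged cache stores per token.
    compressed_kv = self.kv_a_proj(hidden_states).reshape(b, s, 1, self.kv_lora_rank)
    k_pe = self.k_pe_proj(hidden_states).reshape(b, s, 1, self.qk_rope_head_dim)
    out = self.kv_cache.mla_normal(layer_id, q, k, v, compressed_kv, k_pe)
    return self.o_proj(out.reshape(b, s, self.num_heads * self.v_head_dim))
```

Note that `mla_normal` passes full q/k/v to the kernel while caching only `compressed_kv` and `k_pe`, in contrast to `mla_absorbed`, which computes attention directly in the latent space.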

src/runtime/relax_vm/kv_state.cc

Lines changed: 8 additions & 0 deletions
```diff
@@ -90,6 +90,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
                        std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
     });
 
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, NDArray k_data, NDArray v_data, NDArray compressed_kv_data,
+                       NDArray k_pe_data, NDArray o_data) {
+      kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), std::move(compressed_kv_data),
+                          std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
+    });
+
 // RNN State methods
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
```

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 83 additions & 1 deletion
```diff
@@ -2241,7 +2241,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
                  NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
                  double attn_score_scaling_factor) {
-    // Todo(ruihang): implement it
+    // Part 1: basic checks and setup.
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
+    CHECK(q_data.DataType() == pages.DataType());
+    CHECK(k_data.DataType() == pages.DataType());
+    CHECK(v_data.DataType() == pages.DataType());
+    CHECK(compressed_kv_data.DataType() == pages.DataType());
+    CHECK(k_pe_data.DataType() == pages.DataType());
+    CHECK(o_data.DataType() == pages.DataType());
+    CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
+
+    // Expected shapes:
+    //   q_data: (num_total_length, num_qo_heads, qk_head_dim)
+    //   k_data: (num_total_length, num_qo_heads, qk_head_dim)
+    //   v_data: (num_total_length, num_qo_heads, v_head_dim)
+    //   compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
+    //   k_pe_data: (num_total_length, qk_rope_head_dim)
+    //   o_data: (num_total_length, num_qo_heads, v_head_dim)
+    CHECK_EQ(q_data->ndim, 3);
+    CHECK_EQ(k_data->ndim, 3);
+    CHECK_EQ(v_data->ndim, 3);
+    CHECK_EQ(compressed_kv_data->ndim, 2);
+    CHECK_EQ(k_pe_data->ndim, 2);
+    CHECK_EQ(o_data->ndim, 3);
+
+    int64_t total_seq_length = 0;
+    for (int64_t i = 0; i < cur_batch_size_; ++i) {
+      total_seq_length += cur_append_lengths_[i];
+    }
+    CHECK_LE(q_data->shape[0], total_seq_length);
+    CHECK_LE(k_data->shape[0], total_seq_length);
+    CHECK_LE(v_data->shape[0], total_seq_length);
+    CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
+    CHECK_LE(k_pe_data->shape[0], total_seq_length);
+    CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
+    CHECK_LE(o_data->shape[0], total_seq_length);
+    CHECK_EQ(q_data->shape[1], num_qo_heads_);
+    CHECK_EQ(o_data->shape[1], num_qo_heads_);
+    CHECK_EQ(k_data->shape[1], num_qo_heads_);
+    CHECK_EQ(v_data->shape[1], num_qo_heads_);
+    CHECK_EQ(q_data->shape[2], qk_head_dim_);
+    CHECK_EQ(k_data->shape[2], qk_head_dim_);
+    CHECK_EQ(v_data->shape[2], v_head_dim_);
+    CHECK_EQ(o_data->shape[2], v_head_dim_);
+
+    // Part 2: synchronize streams and update auxiliary data.
+    ComputeStreamWaitForCopyStream();
+    ICHECK(!dirty_aux_data_device_);
+
+    // Part 3: append the compressed kv data to the kv-cache now if
+    // append_before_attn_ is set.
+    if (append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
+
+    // Part 4: call the ragged attention kernel.
+    // f_mla_prefill_ragged_normal_ is designed to cover both the decode and the
+    // normal prefill cases. One could alternatively branch on a flag such as
+    // use_decode_kernel_[0] to pick a specialized kernel; here we assume the
+    // kernel handles both cases internally.
+    f_mla_prefill_ragged_normal_(q_data,
+                                 cur_append_length_indptr_view_,
+                                 k_data,
+                                 v_data,
+                                 cur_append_length_indptr_view_,
+                                 q_rope_position_map_view_,
+                                 k_ragged_rope_pos_offset_view_,
+                                 o_data,  // output tensor
+                                 merged_attn_scores_view_,
+                                 /*causal=*/1,
+                                 RoPEMode::kNone,  // RoPE has already been applied before this kernel
+                                 0,  // RoPE parameter, unused with RoPEMode::kNone
+                                 0,  // RoPE parameter, unused with RoPEMode::kNone
+                                 attn_score_scaling_factor);
+
+    // Part 5: if appending happens after attention, call the append kernel now.
+    if (!append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
   }
 
   void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
```
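For reference, the shape contract that the CHECKs above enforce can be written out concretely. A minimal host-side sketch: the shape relations mirror the expected-shapes comment in MLANormal, while the concrete sizes are arbitrary assumptions for illustration.

```python
import numpy as np

# Mock-up of the tensor shapes MLANormal expects. Only the shape relations
# come from the runtime code; the sizes themselves are illustrative.
num_total_length = 16   # flattened batch*seq; must not exceed sum(cur_append_lengths_)
num_qo_heads = 8
qk_rope_head_dim = 64
qk_head_dim = 192       # RoPE part (qk_rope_head_dim) plus the non-RoPE part

v_head_dim = 128

q = np.zeros((num_total_length, num_qo_heads, qk_head_dim), "float16")
k = np.zeros((num_total_length, num_qo_heads, qk_head_dim), "float16")
v = np.zeros((num_total_length, num_qo_heads, v_head_dim), "float16")
compressed_kv = np.zeros((num_total_length, qk_head_dim - qk_rope_head_dim), "float16")
k_pe = np.zeros((num_total_length, qk_rope_head_dim), "float16")
o = np.zeros((num_total_length, num_qo_heads, v_head_dim), "float16")
```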
