@@ -2241,7 +2241,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
                  NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
                  double attn_score_scaling_factor) {
-    // Todo(ruihang): implement it
+    // Part 1: Basic checks and setup.
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
+    CHECK(q_data.DataType() == pages.DataType());
+    CHECK(k_data.DataType() == pages.DataType());
+    CHECK(v_data.DataType() == pages.DataType());
+    CHECK(compressed_kv_data.DataType() == pages.DataType());
+    CHECK(k_pe_data.DataType() == pages.DataType());
+    CHECK(o_data.DataType() == pages.DataType());
+    CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
+
+    // Expected shapes:
+    //   q_data: (num_total_length, num_qo_heads, qk_head_dim)
+    //   k_data: (num_total_length, num_qo_heads, qk_head_dim)
+    //   v_data: (num_total_length, num_qo_heads, v_head_dim)
+    //   compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
+    //   k_pe_data: (num_total_length, qk_rope_head_dim)
+    //   o_data: (num_total_length, num_qo_heads, v_head_dim)
+    CHECK_EQ(q_data->ndim, 3);
+    CHECK_EQ(k_data->ndim, 3);
+    CHECK_EQ(v_data->ndim, 3);
+    CHECK_EQ(compressed_kv_data->ndim, 2);
+    CHECK_EQ(k_pe_data->ndim, 2);
+    CHECK_EQ(o_data->ndim, 3);
+
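+    // Total number of new tokens appended across all sequences in the current batch.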
+    int64_t total_seq_length = 0;
+    for (int64_t i = 0; i < cur_batch_size_; ++i) {
+      total_seq_length += cur_append_lengths_[i];
+    }
+    CHECK_LE(q_data->shape[0], total_seq_length);
+    CHECK_LE(k_data->shape[0], total_seq_length);
+    CHECK_LE(v_data->shape[0], total_seq_length);
+    CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
+    CHECK_LE(k_pe_data->shape[0], total_seq_length);
+    CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
+    CHECK_LE(o_data->shape[0], total_seq_length);
+    CHECK_EQ(q_data->shape[1], num_qo_heads_);
+    CHECK_EQ(o_data->shape[1], num_qo_heads_);
+    CHECK_EQ(k_data->shape[1], num_qo_heads_);
+    CHECK_EQ(v_data->shape[1], num_qo_heads_);
+    CHECK_EQ(q_data->shape[2], qk_head_dim_);
+    CHECK_EQ(k_data->shape[2], qk_head_dim_);
+    CHECK_EQ(v_data->shape[2], v_head_dim_);
+    CHECK_EQ(o_data->shape[2], v_head_dim_);
+
+    // Part 2: Synchronize streams and update auxiliary data.
+    ComputeStreamWaitForCopyStream();
+    ICHECK(!dirty_aux_data_device_);
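+    // After this point the device-side auxiliary views used below (append length indptr,
+    // position maps, RoPE offsets) are expected to be up to date.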
+
+    // Part 3: Append k/v data to the KV cache first if `append_before_attn_` is set.
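+    // (The transpose-append kernel is expected to scatter compressed_kv_data and k_pe_data
+    // into the paged layout at the slots given by append_position_map_view_.)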
+    if (append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
+
+    // Part 4: Call the ragged attention kernel.
+    // f_mla_prefill_ragged_normal_ is designed to work for both the decode and the normal
+    // prefill case. One could optionally branch on a flag such as `use_decode_kernel_[0]`
+    // to adjust parameters; here we assume the kernel supports both cases internally.
+    f_mla_prefill_ragged_normal_(q_data,
+                                 cur_append_length_indptr_view_,  // append length indptr (q side)
+                                 k_data,
+                                 v_data,
+                                 cur_append_length_indptr_view_,  // append length indptr (k/v side)
+                                 q_rope_position_map_view_,
+                                 k_ragged_rope_pos_offset_view_,
+                                 o_data,                    // output tensor
+                                 merged_attn_scores_view_,  // attention score buffer
+                                 /*causal=*/1,
+                                 RoPEMode::kNone,  // RoPE has already been applied before this kernel
+                                 0,                // RoPE param, unused since RoPE mode is kNone
+                                 0,                // RoPE param, unused since RoPE mode is kNone
+                                 attn_score_scaling_factor);
+
+    // Part 5: If appending happens after attention, call the append kernel now.
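+    // The ragged kernel above reads k_data/v_data directly, so the new compressed KV and
+    // k_pe entries only need to be written into the paged cache afterwards for future steps.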
+    if (!append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
+
   }
 
   void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,