Skip to content

Commit ca4e474

Browse files
committed
add FLAGS instead of max_partition_size
1 parent f445a7a commit ca4e474

24 files changed

+26
-116
lines changed

csrc/gpu/append_attention.cu

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
6161
const float out_linear_in_scale,
6262
const int encoder_block_shape_q,
6363
const int decoder_block_shape_q,
64-
const int max_partition_size,
65-
const int encoder_max_partition_size,
6664
const int speculate_max_draft_token_num,
6765
const bool causal,
6866
const bool speculate_decoder) {
@@ -209,8 +207,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
209207
quant_max_bound,
210208
quant_min_bound,
211209
out_linear_in_scale,
212-
max_partition_size,
213-
encoder_max_partition_size,
214210
speculate_max_draft_token_num,
215211
causal,
216212
false,
@@ -248,8 +244,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
248244
quant_max_bound,
249245
quant_min_bound,
250246
out_linear_in_scale,
251-
max_partition_size,
252-
encoder_max_partition_size,
253247
speculate_max_draft_token_num,
254248
causal,
255249
false,
@@ -292,8 +286,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
292286
quant_max_bound,
293287
quant_min_bound,
294288
out_linear_in_scale,
295-
max_partition_size,
296-
encoder_max_partition_size,
297289
speculate_max_draft_token_num,
298290
causal,
299291
false,
@@ -440,8 +432,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
440432
quant_max_bound,
441433
quant_min_bound,
442434
out_linear_in_scale,
443-
max_partition_size,
444-
encoder_max_partition_size,
445435
speculate_max_draft_token_num,
446436
causal,
447437
!speculate_decoder,
@@ -479,8 +469,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
479469
quant_max_bound,
480470
quant_min_bound,
481471
out_linear_in_scale,
482-
max_partition_size,
483-
encoder_max_partition_size,
484472
speculate_max_draft_token_num,
485473
causal,
486474
!speculate_decoder,
@@ -524,8 +512,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
524512
quant_max_bound,
525513
quant_min_bound,
526514
out_linear_in_scale,
527-
max_partition_size,
528-
encoder_max_partition_size,
529515
speculate_max_draft_token_num,
530516
causal,
531517
!speculate_decoder,
@@ -585,8 +571,6 @@ std::vector<paddle::Tensor> AppendAttention(
585571
const float out_linear_in_scale,
586572
const int encoder_block_shape_q,
587573
const int decoder_block_shape_q,
588-
const int max_partition_size,
589-
const int encoder_max_partition_size,
590574
const int speculate_max_draft_token_num,
591575
const bool causal,
592576
const bool speculate_decoder) {
@@ -650,8 +634,6 @@ std::vector<paddle::Tensor> AppendAttention(
650634
out_linear_in_scale,
651635
encoder_block_shape_q,
652636
decoder_block_shape_q,
653-
max_partition_size,
654-
encoder_max_partition_size,
655637
speculate_max_draft_token_num,
656638
causal,
657639
speculate_decoder);
@@ -700,8 +682,6 @@ std::vector<paddle::Tensor> AppendAttention(
700682
out_linear_in_scale,
701683
encoder_block_shape_q,
702684
decoder_block_shape_q,
703-
max_partition_size,
704-
encoder_max_partition_size,
705685
speculate_max_draft_token_num,
706686
causal,
707687
speculate_decoder);
@@ -751,8 +731,6 @@ std::vector<paddle::Tensor> AppendAttention(
751731
out_linear_in_scale,
752732
encoder_block_shape_q,
753733
decoder_block_shape_q,
754-
max_partition_size,
755-
encoder_max_partition_size,
756734
speculate_max_draft_token_num,
757735
causal,
758736
speculate_decoder);
@@ -800,8 +778,6 @@ std::vector<paddle::Tensor> AppendAttention(
800778
out_linear_in_scale,
801779
encoder_block_shape_q,
802780
decoder_block_shape_q,
803-
max_partition_size,
804-
encoder_max_partition_size,
805781
speculate_max_draft_token_num,
806782
causal,
807783
speculate_decoder);
@@ -905,8 +881,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
905881
const float out_linear_in_scale,
906882
const int encoder_block_shape_q,
907883
const int decoder_block_shape_q,
908-
const int max_partition_size,
909-
const int encoder_max_partition_size,
910884
const int speculate_max_draft_token_num,
911885
const bool causal,
912886
const bool speculate_decoder) {
@@ -985,8 +959,6 @@ PD_BUILD_OP(append_attention)
985959
"out_linear_in_scale: float",
986960
"encoder_block_shape_q: int",
987961
"decoder_block_shape_q: int",
988-
"max_partition_size: int",
989-
"encoder_max_partition_size: int",
990962
"speculate_max_draft_token_num: int",
991963
"causal: bool",
992964
"speculate_decoder: bool"})

csrc/gpu/append_attn/append_attention_c16_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,6 @@ void MultiQueryAppendAttention(
786786
const float quant_max_bound,
787787
const float quant_min_bound,
788788
const float in_scale,
789-
const int max_partition_size,
790-
const int encoder_max_partition_size,
791789
const int speculate_max_draft_token_num,
792790
const bool is_decoder,
793791
cudaStream_t &stream,
@@ -839,9 +837,9 @@ void MultiQueryAppendAttention(
839837
int sm_count;
840838
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
841839

842-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
840+
uint32_t chunk_size = get_max_partition_size();
843841
if (!is_decoder) {
844-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
842+
chunk_size = get_encoder_max_partition_size();
845843
}
846844
const int num_chunks = div_up(max_dec_len, chunk_size);
847845
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1058,9 +1056,9 @@ void MultiQueryAppendAttention(
10581056
int sm_count;
10591057
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
10601058

1061-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1059+
uint32_t chunk_size = get_max_partition_size();
10621060
if (!is_decoder) {
1063-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1061+
chunk_size = get_encoder_max_partition_size();
10641062
}
10651063
const int num_chunks = div_up(max_dec_len, chunk_size);
10661064

@@ -1301,8 +1299,6 @@ void CascadeAppendAttentionC16Kernel(
13011299
const float quant_max_bound,
13021300
const float quant_min_bound,
13031301
const float in_scale,
1304-
const int max_partition_size,
1305-
const int encoder_max_partition_size,
13061302
const int speculate_max_draft_token_num,
13071303
const bool causal,
13081304
const bool is_decoder,
@@ -1363,8 +1359,6 @@ void CascadeAppendAttentionC16Kernel(
13631359
quant_max_bound,
13641360
quant_min_bound,
13651361
in_scale,
1366-
max_partition_size,
1367-
encoder_max_partition_size,
13681362
speculate_max_draft_token_num,
13691363
is_decoder,
13701364
stream,

csrc/gpu/append_attn/append_attention_c4_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,6 @@ void MultiQueryAppendC4Attention(
973973
const float quant_max_bound,
974974
const float quant_min_bound,
975975
const float in_scale,
976-
const int max_partition_size,
977-
const int encoder_max_partition_size,
978976
const int speculate_max_draft_token_num,
979977
const bool is_decoder,
980978
cudaStream_t &stream,
@@ -1036,9 +1034,9 @@ void MultiQueryAppendC4Attention(
10361034
const float ratio = static_cast<float>(num_blocks_need) /
10371035
static_cast<float>(num_blocks_per_wave);
10381036

1039-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1037+
uint32_t chunk_size = get_max_partition_size();
10401038
if (!is_decoder) {
1041-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1039+
chunk_size = get_encoder_max_partition_size();
10421040
}
10431041
const int num_chunks = div_up(max_dec_len, chunk_size);
10441042

@@ -1282,9 +1280,9 @@ void MultiQueryAppendC4Attention(
12821280
static_cast<float>(num_blocks_per_wave);
12831281

12841282

1285-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1283+
static uint32_t chunk_size = get_max_partition_size();
12861284
if (!is_decoder) {
1287-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1285+
chunk_size = get_encoder_max_partition_size();
12881286
}
12891287
const int num_chunks = div_up(max_dec_len, chunk_size);
12901288
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1538,8 +1536,6 @@ void CascadeAppendAttentionC4Kernel(
15381536
const float quant_max_bound,
15391537
const float quant_min_bound,
15401538
const float in_scale,
1541-
const int max_partition_size,
1542-
const int encoder_max_partition_size,
15431539
const int speculate_max_draft_token_num,
15441540
const bool causal,
15451541
const bool is_decoder,
@@ -1604,8 +1600,6 @@ void CascadeAppendAttentionC4Kernel(
16041600
quant_max_bound,
16051601
quant_min_bound,
16061602
in_scale,
1607-
max_partition_size,
1608-
encoder_max_partition_size,
16091603
speculate_max_draft_token_num,
16101604
is_decoder,
16111605
stream,

csrc/gpu/append_attn/append_attention_c8_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -860,8 +860,6 @@ void MultiQueryAppendC8Attention(
860860
const float quant_max_bound,
861861
const float quant_min_bound,
862862
const float in_scale,
863-
const int max_partition_size,
864-
const int encoder_max_partition_size,
865863
const int speculate_max_draft_token_num,
866864
const bool is_decoder,
867865
cudaStream_t &stream,
@@ -914,9 +912,9 @@ void MultiQueryAppendC8Attention(
914912
const int dev_id = 0;
915913
int sm_count;
916914
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
917-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
915+
uint32_t chunk_size = get_max_partition_size();
918916
if (!is_decoder) {
919-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
917+
chunk_size = get_encoder_max_partition_size();
920918
}
921919
const int num_chunks = div_up(max_dec_len, chunk_size);
922920
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1136,9 +1134,9 @@ void MultiQueryAppendC8Attention(
11361134
const int dev_id = 0;
11371135
int sm_count;
11381136
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
1139-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1137+
uint32_t chunk_size = get_max_partition_size();
11401138
if (!is_decoder) {
1141-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1139+
chunk_size = get_encoder_max_partition_size();
11421140
}
11431141

11441142
const int num_chunks = div_up(max_dec_len, chunk_size);
@@ -1377,8 +1375,6 @@ void CascadeAppendAttentionC8Kernel(
13771375
const float quant_max_bound,
13781376
const float quant_min_bound,
13791377
const float in_scale,
1380-
const int max_partition_size,
1381-
const int encoder_max_partition_size,
13821378
const int speculate_max_draft_token_num,
13831379
const bool causal,
13841380
const bool is_decoder,
@@ -1441,8 +1437,6 @@ void CascadeAppendAttentionC8Kernel(
14411437
quant_max_bound,
14421438
quant_min_bound,
14431439
in_scale,
1444-
max_partition_size,
1445-
encoder_max_partition_size,
14461440
speculate_max_draft_token_num,
14471441
is_decoder,
14481442
stream,

csrc/gpu/append_attn/append_attention_kernel.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ void CascadeAppendAttentionC16Kernel(
5252
const float quant_max_bound,
5353
const float quant_min_bound,
5454
const float in_scale,
55-
const int max_partition_size,
56-
const int encoder_max_partition_size,
5755
const int speculate_max_draft_token_num,
5856
const bool causal,
5957
const bool is_decoder,
@@ -97,8 +95,6 @@ void CascadeAppendAttentionC8Kernel(
9795
const float quant_max_bound,
9896
const float quant_min_bound,
9997
const float in_scale,
100-
const int max_partition_size,
101-
const int encoder_max_partition_size,
10298
const int speculate_max_draft_token_num,
10399
const bool causal,
104100
const bool is_decoder,
@@ -142,8 +138,6 @@ void CascadeAppendAttentionC4Kernel(
142138
const float quant_max_bound,
143139
const float quant_min_bound,
144140
const float in_scale,
145-
const int max_partition_size,
146-
const int encoder_max_partition_size,
147141
const int speculate_max_draft_token_num,
148142
const bool causal,
149143
const bool is_decoder,
@@ -188,8 +182,6 @@ void CascadeAppendAttentionKernel(
188182
const float quant_max_bound,
189183
const float quant_min_bound,
190184
const float in_scale,
191-
const int max_partition_size,
192-
const int encoder_max_partition_size,
193185
const int speculate_max_draft_token_num,
194186
const bool causal,
195187
const bool is_decoder,
@@ -223,8 +215,6 @@ void CascadeAppendAttentionKernel(
223215
quant_max_bound,
224216
quant_min_bound,
225217
in_scale,
226-
max_partition_size,
227-
encoder_max_partition_size,
228218
speculate_max_draft_token_num,
229219
causal,
230220
is_decoder,
@@ -258,8 +248,6 @@ void CascadeAppendAttentionKernel(
258248
quant_max_bound,
259249
quant_min_bound,
260250
in_scale,
261-
max_partition_size,
262-
encoder_max_partition_size,
263251
speculate_max_draft_token_num,
264252
causal,
265253
is_decoder,
@@ -293,8 +281,6 @@ void CascadeAppendAttentionKernel(
293281
quant_max_bound,
294282
quant_min_bound,
295283
in_scale,
296-
max_partition_size,
297-
encoder_max_partition_size,
298284
speculate_max_draft_token_num,
299285
causal,
300286
is_decoder,
@@ -307,3 +293,17 @@ void CascadeAppendAttentionKernel(
307293
"cache_int4_zp]");
308294
}
309295
}
296+
297+
// Returns the decoder-side attention chunk (partition) size.
// Read once from the FLAGS_cascade_attention_max_partition_size environment
// variable and cached for the lifetime of the process (C++11 magic static,
// thread-safe). Falls back to the default of 128 when the variable is unset
// or does not parse as a positive integer — unlike std::stoul, this never
// throws on a malformed value, and the parse is done with strtoul so no
// temporary std::string is constructed.
inline uint32_t get_max_partition_size() {
  static const uint32_t max_partition_size = []() -> uint32_t {
    const char* env = std::getenv("FLAGS_cascade_attention_max_partition_size");
    if (env == nullptr) return 128u;
    char* end = nullptr;
    const unsigned long parsed = std::strtoul(env, &end, 10);
    // Reject empty/garbage values and anything that would narrow incorrectly.
    if (end == env || *end != '\0' || parsed == 0 || parsed > 0xFFFFFFFFul) {
      return 128u;
    }
    return static_cast<uint32_t>(parsed);
  }();
  return max_partition_size;
}
303+
304+
// Returns the encoder-side attention chunk (partition) size.
// Read once from the FLAGS_cascade_encoder_attention_max_partition_size
// environment variable and cached for the lifetime of the process (C++11
// magic static, thread-safe). Falls back to the default of 32768 when the
// variable is unset or does not parse as a positive integer — unlike
// std::stoul, this never throws on a malformed value, and the parse is done
// with strtoul so no temporary std::string is constructed.
inline uint32_t get_encoder_max_partition_size() {
  static const uint32_t encoder_max_partition_size = []() -> uint32_t {
    const char* env =
        std::getenv("FLAGS_cascade_encoder_attention_max_partition_size");
    if (env == nullptr) return 32768u;
    char* end = nullptr;
    const unsigned long parsed = std::strtoul(env, &end, 10);
    // Reject empty/garbage values and anything that would narrow incorrectly.
    if (end == env || *end != '\0' || parsed == 0 || parsed > 0xFFFFFFFFul) {
      return 32768u;
    }
    return static_cast<uint32_t>(parsed);
  }();
  return encoder_max_partition_size;
}

csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
4949
const float quant_max_bound,
5050
const float quant_min_bound,
5151
const float in_scale,
52-
const int max_partition_size,
53-
const int encoder_max_partition_size,
5452
const int speculate_max_draft_token_num,
5553
const bool causal,
5654
const bool is_decoder,

csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
4848
const float quant_max_bound,
4949
const float quant_min_bound,
5050
const float in_scale,
51-
const int max_partition_size,
52-
const int encoder_max_partition_size,
5351
const int speculate_max_draft_token_num,
5452
const bool causal,
5553
const bool is_decoder,

csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
4848
const float quant_max_bound,
4949
const float quant_min_bound,
5050
const float in_scale,
51-
const int max_partition_size,
52-
const int encoder_max_partition_size,
5351
const int speculate_max_draft_token_num,
5452
const bool causal,
5553
const bool is_decoder,

csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
4848
const float quant_max_bound,
4949
const float quant_min_bound,
5050
const float in_scale,
51-
const int max_partition_size,
52-
const int encoder_max_partition_size,
5351
const int speculate_max_draft_token_num,
5452
const bool causal,
5553
const bool is_decoder,

0 commit comments

Comments
 (0)