Skip to content

Commit 0bc48b6

Browse files
luanyunduWSEmma
andauthored
Fix the bug that total expert num greater than 256 or local expert num is less than 8 (#364)
Co-authored-by: WSEmma <wusemma@163.com>
1 parent 98bc6f6 commit 0bc48b6

File tree

2 files changed

+55
-53
lines changed

2 files changed

+55
-53
lines changed

csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi
398398
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
399399
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
400400
OP_TILING_CHECK(xDim0 != topkWeightsDim0,
401-
OP_LOGE(nodeName, "x's dim0 is greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
401+
OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
402402
return false);
403403
OP_TILING_CHECK(xDim1 != recvXDim1,
404404
OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1),

csrc/deepep/ops/op_kernel/notify_dispatch.h

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class NotifyDispatch
5151
// Synchronization flag occupies length
5252
constexpr static int64_t FLAG_UNIT_INT_NUM = 4;
5353
constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1);
54-
constexpr static int32_t BATCH_ROUND = 32;
54+
constexpr static int32_t EXPERT_NORMAL_NUM = 256;
55+
constexpr static int32_t BATCH_ROUND = 16;
5556

5657
public:
5758
__aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag)
@@ -71,7 +72,7 @@ class NotifyDispatch
7172
recvOffset_ = recvOffset;
7273
maxBs_ = maxBs;
7374
recvTokensPerExpert_ = recvTokensPerExpert;
74-
batchRounds = BATCH_ROUND;
75+
batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2;
7576
tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7677
sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7778
sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
@@ -339,12 +340,14 @@ class NotifyDispatch
339340
uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup;
340341
uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup;
341342
uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t);
343+
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
344+
uint32_t strideElem = alignedDataLen / sizeof(int32_t); // 目标地址也改变,使用对齐后的地址
342345
DataCopyExtParams recvDataParams = {1U, static_cast<uint32_t>(singleRankBatchDataLen), 0, 0, 0};
343346
DataCopyPadExtParams<T> DataCopyPadExtParams{false, 0U, 0U, 0U};
344347

345348
for (uint32_t i = 0; i < rankSize; i++) {
346349
uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup;
347-
uint32_t dstOffset = i * singleRankBatchElemCount;
350+
uint32_t dstOffset = i * strideElem;
348351
// 搬运该Rank下的 currentBatchRounds 数据
349352
DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams);
350353
}
@@ -357,14 +360,18 @@ class NotifyDispatch
357360
Duplicate<T>(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V
358361

359362
SyncFunc<AscendC::HardEvent::V_S>();
363+
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
364+
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
365+
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
360366
uint32_t computeNum = currentBatchRounds * numLocalExperts;
361367
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
362368
uint32_t computeNumIn = r * numLocalExperts;
363369
uint32_t computeNumOut = r * numExperts;
364370
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
365371
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
366372
uint32_t index = expId * rankSize + srcRank;
367-
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
373+
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
374+
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
368375
recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx);
369376
}
370377
}
@@ -376,56 +383,34 @@ class NotifyDispatch
376383
sendOffsetTensor = sendOffsetBuf.Get<T>();
377384
Duplicate<T>(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t));
378385
SyncFunc<AscendC::HardEvent::V_S>();
386+
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
387+
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
388+
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
379389
uint32_t computeNum = currentBatchRounds * numLocalExperts;
380390
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
381391
uint32_t computeNumIn = r * numLocalExperts;
382392
uint32_t computeNumOut = r * numExperts;
383393
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
384394
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
385395
uint32_t index = expId * rankSize + srcRank;
386-
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
396+
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
397+
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
387398
sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1);
388399
}
389400
}
390401
}
391402
}
392403

393-
__aicore__ inline void ReorderSendTokensPerRankOutput()
394-
{
395-
pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
396-
pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen);
397-
sendTokensPerRankTensor = sendTokensPerRankBuf.Get<int32_t>();
398-
seenRoundTensor = seenRoundBuf.Get<int32_t>();
399-
Duplicate<int32_t>(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
400-
SyncFunc<AscendC::HardEvent::V_S>();
401-
SyncFunc<AscendC::HardEvent::MTE2_S>();
402-
for (uint32_t r = 0; r < round; ++r) {
403-
Duplicate<int32_t>(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
404-
SyncFunc<AscendC::HardEvent::V_S>();
405-
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
406-
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
407-
uint32_t index = expId * rankSize + srcRank;
408-
uint32_t pair_idx =
409-
sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId);
410-
if (!seenRoundTensor(srcRank)) {
411-
sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
412-
seenRoundTensor(srcRank) = 1;
413-
}
414-
}
415-
}
416-
SyncFunc<AscendC::HardEvent::S_V>();
417-
}
418-
}
419-
420404
__aicore__ inline void BuildTotalRecvTokens()
421405
{
422406
if (blockIdx != TOTAL_CNT_CORE) {
423407
return;
424408
}
425409
int32_t sumVal = 0;
426-
427-
recvDataAlignLen =
428-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
410+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
411+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
412+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
413+
recvDataAlignLen = rankSize * singleRankAlignLen;
429414
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
430415
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
431416
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -467,8 +452,10 @@ class NotifyDispatch
467452
if (blockIdx != RECV_COUNT_CORE) {
468453
return;
469454
}
470-
recvDataAlignLen =
471-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
455+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
456+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
457+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
458+
recvDataAlignLen = rankSize * singleRankAlignLen;
472459
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
473460
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
474461
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -505,8 +492,10 @@ class NotifyDispatch
505492
if (blockIdx != RECV_OFFSET_CORE) {
506493
return;
507494
}
508-
recvDataAlignLen =
509-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
495+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
496+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
497+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
498+
recvDataAlignLen = rankSize * singleRankAlignLen;
510499
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
511500
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
512501
pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen);
@@ -535,8 +524,10 @@ class NotifyDispatch
535524
if (blockIdx != MAX_BS_CORE) {
536525
return;
537526
}
538-
recvDataAlignLen =
539-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
527+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
528+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
529+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
530+
recvDataAlignLen = rankSize * singleRankAlignLen;
540531
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
541532

542533
pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
@@ -549,16 +540,19 @@ class NotifyDispatch
549540
SyncFunc<AscendC::HardEvent::MTE2_S>();
550541
for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) {
551542
uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds;
552-
543+
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
544+
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
545+
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
553546
ReorderOutput(rStart, currentBatchRounds);
554547
SyncFunc<AscendC::HardEvent::MTE2_S>();
555548
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
549+
uint32_t offsetInRound = r * numLocalExperts;
556550
Duplicate<int32_t>(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
557551
SyncFunc<AscendC::HardEvent::V_S>();
558552
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
559553
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
560-
uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds +
561-
r * numLocalExperts + expId);
554+
uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId);
555+
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
562556
if (!seenRoundTensor(srcRank)) {
563557
sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
564558
seenRoundTensor(srcRank) = 1;
@@ -585,8 +579,10 @@ class NotifyDispatch
585579
if (blockIdx != RECV_TOKEN_PER_EXP_CORE) {
586580
return;
587581
}
588-
recvDataAlignLen =
589-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
582+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
583+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
584+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
585+
recvDataAlignLen = rankSize * singleRankAlignLen;
590586
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
591587
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
592588
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -630,8 +626,10 @@ class NotifyDispatch
630626
return;
631627
}
632628

633-
recvDataAlignLen =
634-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
629+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
630+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
631+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
632+
recvDataAlignLen = rankSize * singleRankAlignLen;
635633
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
636634
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
637635
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -676,8 +674,10 @@ class NotifyDispatch
676674
if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) {
677675
return;
678676
}
679-
recvDataAlignLen =
680-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
677+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
678+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
679+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
680+
recvDataAlignLen = rankSize * singleRankAlignLen;
681681
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
682682
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
683683
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -726,8 +726,10 @@ class NotifyDispatch
726726
if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) {
727727
return;
728728
}
729-
recvDataAlignLen =
730-
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
729+
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
730+
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
731+
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
732+
recvDataAlignLen = rankSize * singleRankAlignLen;
731733
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
732734
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
733735
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);

0 commit comments

Comments
 (0)