@@ -51,8 +51,7 @@ class NotifyDispatch
5151 // Length occupied by the synchronization flag
5252 constexpr static int64_t FLAG_UNIT_INT_NUM = 4 ;
5353 constexpr static int64_t MAGIC_MASK = ~((1LL << 32 ) - 1 );
54- constexpr static int32_t EXPERT_NORMAL_NUM = 256 ;
55- constexpr static int32_t BATCH_ROUND = 16 ;
54+ constexpr static int32_t BATCH_ROUND = 32 ;
5655
5756public:
5857 __aicore__ inline NotifyDispatch (int rank, int rankSize, uint32_t extraFlag)
@@ -72,7 +71,7 @@ class NotifyDispatch
7271 recvOffset_ = recvOffset;
7372 maxBs_ = maxBs;
7473 recvTokensPerExpert_ = recvTokensPerExpert;
75- batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2 ;
74+ batchRounds = BATCH_ROUND;
7675 tokenPerExpertDataAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7776 sendDataOffsetAlignLen = Ceil (batchRounds * numExperts * sizeof (T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7877 sendDataAlignLen = Ceil (batchRounds * numExperts * sendPerGroup * sizeof (T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
@@ -340,14 +339,12 @@ class NotifyDispatch
340339 uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup;
341340 uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup;
342341 uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof (int32_t );
343- uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
344- uint32_t strideElem = alignedDataLen / sizeof (int32_t ); // 目标地址也改变,使用对齐后的地址
345342 DataCopyExtParams recvDataParams = {1U , static_cast <uint32_t >(singleRankBatchDataLen), 0 , 0 , 0 };
346343 DataCopyPadExtParams<T> DataCopyPadExtParams{false , 0U , 0U , 0U };
347344
348345 for (uint32_t i = 0 ; i < rankSize; i++) {
349346 uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup;
350- uint32_t dstOffset = i * strideElem ;
347+ uint32_t dstOffset = i * singleRankBatchElemCount ;
351348 // Copy this rank's data for the currentBatchRounds rounds
352349 DataCopyPad (recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams);
353350 }
@@ -360,18 +357,14 @@ class NotifyDispatch
360357 Duplicate<T>(recvCountTensor, 0 , sendCountAlignLen / sizeof (int32_t )); // V
361358
362359 SyncFunc<AscendC::HardEvent::V_S>();
363- uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
364- uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
365- uint32_t strideElem = alignedDataLen / sizeof (int32_t );
366360 uint32_t computeNum = currentBatchRounds * numLocalExperts;
367361 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
368362 uint32_t computeNumIn = r * numLocalExperts;
369363 uint32_t computeNumOut = r * numExperts;
370364 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
371365 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
372366 uint32_t index = expId * rankSize + srcRank;
373- uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
374- uint32_t pair_idx = srcRank * strideElem + offsetInRank;
367+ uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
375368 recvCountTensor (computeNumOut + index) = recvDataTensor (pair_idx);
376369 }
377370 }
@@ -383,34 +376,56 @@ class NotifyDispatch
383376 sendOffsetTensor = sendOffsetBuf.Get <T>();
384377 Duplicate<T>(sendOffsetTensor, 0 , sendCountAlignLen / sizeof (int32_t ));
385378 SyncFunc<AscendC::HardEvent::V_S>();
386- uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
387- uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
388- uint32_t strideElem = alignedDataLen / sizeof (int32_t );
389379 uint32_t computeNum = currentBatchRounds * numLocalExperts;
390380 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
391381 uint32_t computeNumIn = r * numLocalExperts;
392382 uint32_t computeNumOut = r * numExperts;
393383 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
394384 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
395385 uint32_t index = expId * rankSize + srcRank;
396- uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
397- uint32_t pair_idx = srcRank * strideElem + offsetInRank;
386+ uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
398387 sendOffsetTensor (computeNumOut + index) = recvDataTensor (pair_idx + 1 );
399388 }
400389 }
401390 }
402391 }
403392
393+ __aicore__ inline void ReorderSendTokensPerRankOutput ()
394+ {
395+ pipe.InitBuffer (sendTokensPerRankBuf, sendTokensPerRankAlignLen);
396+ pipe.InitBuffer (seenRoundBuf, sendTokensPerRankAlignLen);
397+ sendTokensPerRankTensor = sendTokensPerRankBuf.Get <int32_t >();
398+ seenRoundTensor = seenRoundBuf.Get <int32_t >();
399+ Duplicate<int32_t >(sendTokensPerRankTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
400+ SyncFunc<AscendC::HardEvent::V_S>();
401+ SyncFunc<AscendC::HardEvent::MTE2_S>();
402+ for (uint32_t r = 0 ; r < round; ++r) {
403+ Duplicate<int32_t >(seenRoundTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
404+ SyncFunc<AscendC::HardEvent::V_S>();
405+ for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
406+ for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
407+ uint32_t index = expId * rankSize + srcRank;
408+ uint32_t pair_idx =
409+ sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId);
410+ if (!seenRoundTensor (srcRank)) {
411+ sendTokensPerRankTensor (srcRank) += recvDataTensor (pair_idx + 2 );
412+ seenRoundTensor (srcRank) = 1 ;
413+ }
414+ }
415+ }
416+ SyncFunc<AscendC::HardEvent::S_V>();
417+ }
418+ }
419+
404420 __aicore__ inline void BuildTotalRecvTokens ()
405421 {
406422 if (blockIdx != TOTAL_CNT_CORE) {
407423 return ;
408424 }
409425 int32_t sumVal = 0 ;
410- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
411- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
412- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
413- recvDataAlignLen = rankSize * singleRankAlignLen;
426+
427+ recvDataAlignLen =
428+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
414429 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
415430 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
416431 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -452,10 +467,8 @@ class NotifyDispatch
452467 if (blockIdx != RECV_COUNT_CORE) {
453468 return ;
454469 }
455- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
456- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
457- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
458- recvDataAlignLen = rankSize * singleRankAlignLen;
470+ recvDataAlignLen =
471+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
459472 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
460473 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
461474 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -492,10 +505,8 @@ class NotifyDispatch
492505 if (blockIdx != RECV_OFFSET_CORE) {
493506 return ;
494507 }
495- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
496- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
497- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
498- recvDataAlignLen = rankSize * singleRankAlignLen;
508+ recvDataAlignLen =
509+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
499510 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
500511 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
501512 pipe.InitBuffer (sendOffsetBuf, sendCountAlignLen);
@@ -524,10 +535,8 @@ class NotifyDispatch
524535 if (blockIdx != MAX_BS_CORE) {
525536 return ;
526537 }
527- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
528- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
529- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
530- recvDataAlignLen = rankSize * singleRankAlignLen;
538+ recvDataAlignLen =
539+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
531540 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
532541
533542 pipe.InitBuffer (sendTokensPerRankBuf, sendTokensPerRankAlignLen);
@@ -540,19 +549,16 @@ class NotifyDispatch
540549 SyncFunc<AscendC::HardEvent::MTE2_S>();
541550 for (uint32_t rStart = 0 ; rStart < round; rStart += batchRounds) {
542551 uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds;
543- uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
544- uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
545- uint32_t strideElem = alignedDataLen / sizeof (int32_t );
552+
546553 ReorderOutput (rStart, currentBatchRounds);
547554 SyncFunc<AscendC::HardEvent::MTE2_S>();
548555 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
549- uint32_t offsetInRound = r * numLocalExperts;
550556 Duplicate<int32_t >(seenRoundTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
551557 SyncFunc<AscendC::HardEvent::V_S>();
552558 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
553559 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
554- uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId);
555- uint32_t pair_idx = srcRank * strideElem + offsetInRank ;
560+ uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds +
561+ r * numLocalExperts + expId) ;
556562 if (!seenRoundTensor (srcRank)) {
557563 sendTokensPerRankTensor (srcRank) += recvDataTensor (pair_idx + 2 );
558564 seenRoundTensor (srcRank) = 1 ;
@@ -579,10 +585,8 @@ class NotifyDispatch
579585 if (blockIdx != RECV_TOKEN_PER_EXP_CORE) {
580586 return ;
581587 }
582- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
583- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
584- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
585- recvDataAlignLen = rankSize * singleRankAlignLen;
588+ recvDataAlignLen =
589+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
586590 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
587591 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
588592 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -626,10 +630,8 @@ class NotifyDispatch
626630 return ;
627631 }
628632
629- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
630- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
631- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
632- recvDataAlignLen = rankSize * singleRankAlignLen;
633+ recvDataAlignLen =
634+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
633635 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
634636 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
635637 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -674,10 +676,8 @@ class NotifyDispatch
674676 if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) {
675677 return ;
676678 }
677- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
678- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
679- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
680- recvDataAlignLen = rankSize * singleRankAlignLen;
679+ recvDataAlignLen =
680+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
681681 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
682682 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
683683 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -726,10 +726,8 @@ class NotifyDispatch
726726 if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) {
727727 return ;
728728 }
729- uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
730- uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
731- uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
732- recvDataAlignLen = rankSize * singleRankAlignLen;
729+ recvDataAlignLen =
730+ Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
733731 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
734732 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
735733 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
0 commit comments