@@ -51,7 +51,8 @@ class NotifyDispatch
5151 // Synchronization flag occupies length
5252 constexpr static int64_t FLAG_UNIT_INT_NUM = 4 ;
5353 constexpr static int64_t MAGIC_MASK = ~((1LL << 32 ) - 1 );
54- constexpr static int32_t BATCH_ROUND = 32 ;
54+ constexpr static int32_t EXPERT_NORMAL_NUM = 256 ;
55+ constexpr static int32_t BATCH_ROUND = 16 ;
5556
5657public:
5758 __aicore__ inline NotifyDispatch (int rank, int rankSize, uint32_t extraFlag)
@@ -71,7 +72,7 @@ class NotifyDispatch
7172 recvOffset_ = recvOffset;
7273 maxBs_ = maxBs;
7374 recvTokensPerExpert_ = recvTokensPerExpert;
74- batchRounds = BATCH_ROUND;
75+ batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2 ;
7576 tokenPerExpertDataAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7677 sendDataOffsetAlignLen = Ceil (batchRounds * numExperts * sizeof (T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7778 sendDataAlignLen = Ceil (batchRounds * numExperts * sendPerGroup * sizeof (T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
@@ -339,12 +340,14 @@ class NotifyDispatch
339340 uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup;
340341 uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup;
341342 uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof (int32_t );
343+ uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
344+ uint32_t strideElem = alignedDataLen / sizeof (int32_t ); // 目标地址也改变,使用对齐后的地址
342345 DataCopyExtParams recvDataParams = {1U , static_cast <uint32_t >(singleRankBatchDataLen), 0 , 0 , 0 };
343346 DataCopyPadExtParams<T> DataCopyPadExtParams{false , 0U , 0U , 0U };
344347
345348 for (uint32_t i = 0 ; i < rankSize; i++) {
346349 uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup;
347- uint32_t dstOffset = i * singleRankBatchElemCount ;
350+ uint32_t dstOffset = i * strideElem ;
348351 // 搬运该Rank下的 currentBatchRounds 数据
349352 DataCopyPad (recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams);
350353 }
@@ -357,14 +360,18 @@ class NotifyDispatch
357360 Duplicate<T>(recvCountTensor, 0 , sendCountAlignLen / sizeof (int32_t )); // V
358361
359362 SyncFunc<AscendC::HardEvent::V_S>();
363+ uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
364+ uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
365+ uint32_t strideElem = alignedDataLen / sizeof (int32_t );
360366 uint32_t computeNum = currentBatchRounds * numLocalExperts;
361367 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
362368 uint32_t computeNumIn = r * numLocalExperts;
363369 uint32_t computeNumOut = r * numExperts;
364370 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
365371 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
366372 uint32_t index = expId * rankSize + srcRank;
367- uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
373+ uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
374+ uint32_t pair_idx = srcRank * strideElem + offsetInRank;
368375 recvCountTensor (computeNumOut + index) = recvDataTensor (pair_idx);
369376 }
370377 }
@@ -376,56 +383,34 @@ class NotifyDispatch
376383 sendOffsetTensor = sendOffsetBuf.Get <T>();
377384 Duplicate<T>(sendOffsetTensor, 0 , sendCountAlignLen / sizeof (int32_t ));
378385 SyncFunc<AscendC::HardEvent::V_S>();
386+ uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
387+ uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
388+ uint32_t strideElem = alignedDataLen / sizeof (int32_t );
379389 uint32_t computeNum = currentBatchRounds * numLocalExperts;
380390 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
381391 uint32_t computeNumIn = r * numLocalExperts;
382392 uint32_t computeNumOut = r * numExperts;
383393 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
384394 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
385395 uint32_t index = expId * rankSize + srcRank;
386- uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
396+ uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
397+ uint32_t pair_idx = srcRank * strideElem + offsetInRank;
387398 sendOffsetTensor (computeNumOut + index) = recvDataTensor (pair_idx + 1 );
388399 }
389400 }
390401 }
391402 }
392403
393- __aicore__ inline void ReorderSendTokensPerRankOutput ()
394- {
395- pipe.InitBuffer (sendTokensPerRankBuf, sendTokensPerRankAlignLen);
396- pipe.InitBuffer (seenRoundBuf, sendTokensPerRankAlignLen);
397- sendTokensPerRankTensor = sendTokensPerRankBuf.Get <int32_t >();
398- seenRoundTensor = seenRoundBuf.Get <int32_t >();
399- Duplicate<int32_t >(sendTokensPerRankTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
400- SyncFunc<AscendC::HardEvent::V_S>();
401- SyncFunc<AscendC::HardEvent::MTE2_S>();
402- for (uint32_t r = 0 ; r < round; ++r) {
403- Duplicate<int32_t >(seenRoundTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
404- SyncFunc<AscendC::HardEvent::V_S>();
405- for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
406- for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
407- uint32_t index = expId * rankSize + srcRank;
408- uint32_t pair_idx =
409- sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId);
410- if (!seenRoundTensor (srcRank)) {
411- sendTokensPerRankTensor (srcRank) += recvDataTensor (pair_idx + 2 );
412- seenRoundTensor (srcRank) = 1 ;
413- }
414- }
415- }
416- SyncFunc<AscendC::HardEvent::S_V>();
417- }
418- }
419-
420404 __aicore__ inline void BuildTotalRecvTokens ()
421405 {
422406 if (blockIdx != TOTAL_CNT_CORE) {
423407 return ;
424408 }
425409 int32_t sumVal = 0 ;
426-
427- recvDataAlignLen =
428- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
410+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
411+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
412+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
413+ recvDataAlignLen = rankSize * singleRankAlignLen;
429414 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
430415 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
431416 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -467,8 +452,10 @@ class NotifyDispatch
467452 if (blockIdx != RECV_COUNT_CORE) {
468453 return ;
469454 }
470- recvDataAlignLen =
471- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
455+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
456+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
457+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
458+ recvDataAlignLen = rankSize * singleRankAlignLen;
472459 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
473460 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
474461 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -505,8 +492,10 @@ class NotifyDispatch
505492 if (blockIdx != RECV_OFFSET_CORE) {
506493 return ;
507494 }
508- recvDataAlignLen =
509- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
495+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
496+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
497+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
498+ recvDataAlignLen = rankSize * singleRankAlignLen;
510499 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
511500 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
512501 pipe.InitBuffer (sendOffsetBuf, sendCountAlignLen);
@@ -535,8 +524,10 @@ class NotifyDispatch
535524 if (blockIdx != MAX_BS_CORE) {
536525 return ;
537526 }
538- recvDataAlignLen =
539- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
527+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
528+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
529+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
530+ recvDataAlignLen = rankSize * singleRankAlignLen;
540531 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
541532
542533 pipe.InitBuffer (sendTokensPerRankBuf, sendTokensPerRankAlignLen);
@@ -549,16 +540,19 @@ class NotifyDispatch
549540 SyncFunc<AscendC::HardEvent::MTE2_S>();
550541 for (uint32_t rStart = 0 ; rStart < round; rStart += batchRounds) {
551542 uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds;
552-
543+ uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof (int32_t );
544+ uint32_t alignedDataLen = Ceil (singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
545+ uint32_t strideElem = alignedDataLen / sizeof (int32_t );
553546 ReorderOutput (rStart, currentBatchRounds);
554547 SyncFunc<AscendC::HardEvent::MTE2_S>();
555548 for (uint32_t r = 0 ; r < currentBatchRounds; ++r) {
549+ uint32_t offsetInRound = r * numLocalExperts;
556550 Duplicate<int32_t >(seenRoundTensor, 0 , sendTokensPerRankAlignLen / sizeof (int32_t ));
557551 SyncFunc<AscendC::HardEvent::V_S>();
558552 for (uint32_t expId = 0 ; expId < numLocalExperts; ++expId) {
559553 for (uint32_t srcRank = 0 ; srcRank < rankSize; ++srcRank) {
560- uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds +
561- r * numLocalExperts + expId) ;
554+ uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId);
555+ uint32_t pair_idx = srcRank * strideElem + offsetInRank ;
562556 if (!seenRoundTensor (srcRank)) {
563557 sendTokensPerRankTensor (srcRank) += recvDataTensor (pair_idx + 2 );
564558 seenRoundTensor (srcRank) = 1 ;
@@ -585,8 +579,10 @@ class NotifyDispatch
585579 if (blockIdx != RECV_TOKEN_PER_EXP_CORE) {
586580 return ;
587581 }
588- recvDataAlignLen =
589- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
582+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
583+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
584+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
585+ recvDataAlignLen = rankSize * singleRankAlignLen;
590586 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
591587 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
592588 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -630,8 +626,10 @@ class NotifyDispatch
630626 return ;
631627 }
632628
633- recvDataAlignLen =
634- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
629+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
630+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
631+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
632+ recvDataAlignLen = rankSize * singleRankAlignLen;
635633 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
636634 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
637635 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -676,8 +674,10 @@ class NotifyDispatch
676674 if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) {
677675 return ;
678676 }
679- recvDataAlignLen =
680- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
677+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
678+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
679+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
680+ recvDataAlignLen = rankSize * singleRankAlignLen;
681681 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
682682 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
683683 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
@@ -726,8 +726,10 @@ class NotifyDispatch
726726 if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) {
727727 return ;
728728 }
729- recvDataAlignLen =
730- Ceil (batchRounds * numExperts * sendPerGroup * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
729+ uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
730+ uint32_t singleRankMaxLen = singleRankMaxElem * sizeof (int32_t );
731+ uint32_t singleRankAlignLen = Ceil (singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
732+ recvDataAlignLen = rankSize * singleRankAlignLen;
731733 pipe.InitBuffer (recvDataBuf, recvDataAlignLen);
732734 sendCountAlignLen = Ceil (batchRounds * numExperts * sizeof (int32_t ), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
733735 pipe.InitBuffer (recvCountBuf, sendCountAlignLen);
0 commit comments