Skip to content

Commit 854e668

Browse files
committed
implement ops/ops2 separation logic
1 parent 0bc48b6 commit 854e668

27 files changed

+633
-1762
lines changed

.github/workflows/pr-test-npu.yml

Lines changed: 54 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,47 @@ concurrency:
1818
cancel-in-progress: true
1919

2020
jobs:
21+
get-changed-files:
22+
name: Check changed files
23+
runs-on: ubuntu-latest
24+
outputs:
25+
ops2_changed: ${{ steps.match-groups.outputs.ops2_any_changed }}
26+
ops_changed: ${{ steps.match-groups.outputs.ops_any_changed }}
27+
common_changed: ${{ steps.match-groups.outputs.common_any_changed }}
28+
steps:
29+
- name: Checkout code
30+
uses: actions/checkout@v4
31+
with:
32+
fetch-depth: 0
33+
34+
- name: Match changed files
35+
id: match-groups
36+
uses: tj-actions/changed-files@v45
37+
with:
38+
files_yaml: |
39+
ops:
40+
- csrc/deepep/ops/**
41+
ops2:
42+
- csrc/deepep/ops2/**
43+
common:
44+
- csrc/**/*
45+
- '!csrc/deepep/ops/**'
46+
- '!csrc/deepep/ops2/**'
47+
- build.sh
48+
- cmake/**
49+
- python/**
50+
- test/**
51+
- scripts/**
52+
- .github/workflows/pr-test-npu.yml
53+
2154
test-all-build:
22-
if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
23-
github.event.pull_request.draft == false
55+
needs: get-changed-files
56+
if: |
57+
github.event_name == 'workflow_dispatch' || (
58+
(github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
59+
github.event.pull_request.draft == false &&
60+
(needs.get-changed-files.outputs.ops_changed == 'true' || needs.get-changed-files.outputs.common_changed == 'true')
61+
)
2462
runs-on: linux-aarch64-a3-16
2563
container:
2664
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-a3-ubuntu22.04-py3.11
@@ -336,8 +374,13 @@ jobs:
336374
run: bash scripts/generalization_test_fused_deep_moe.sh
337375

338376
test-build-deepep-a3:
339-
if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
340-
github.event.pull_request.draft == false
377+
needs: get-changed-files
378+
if: |
379+
github.event_name == 'workflow_dispatch' || (
380+
(github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
381+
github.event.pull_request.draft == false &&
382+
(needs.get-changed-files.outputs.ops_changed == 'true' || needs.get-changed-files.outputs.common_changed == 'true')
383+
)
341384
runs-on: linux-aarch64-a3-16
342385
container:
343386
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-a3-ubuntu22.04-py3.11
@@ -653,8 +696,13 @@ jobs:
653696
run: bash scripts/generalization_test_fused_deep_moe.sh
654697

655698
test-build-deepep-a2:
656-
if: (github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
657-
github.event.pull_request.draft == false
699+
needs: get-changed-files
700+
if: |
701+
github.event_name == 'workflow_dispatch' || (
702+
(github.repository == 'sgl-project/sgl-kernel-npu' || github.event_name == 'pull_request') &&
703+
github.event.pull_request.draft == false &&
704+
(needs.get-changed-files.outputs.ops2_changed == 'true' || needs.get-changed-files.outputs.common_changed == 'true')
705+
)
658706
runs-on: linux-aarch64-a2-8
659707
container:
660708
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11

csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -398,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi
398398
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
399399
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
400400
OP_TILING_CHECK(xDim0 != topkWeightsDim0,
401-
OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
401+
OP_LOGE(nodeName, "x's dim0 is greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
402402
return false);
403403
OP_TILING_CHECK(xDim1 != recvXDim1,
404404
OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1),

csrc/deepep/ops/op_kernel/notify_dispatch.h

Lines changed: 52 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,7 @@ class NotifyDispatch
5151
// Synchronization flag occupies length
5252
constexpr static int64_t FLAG_UNIT_INT_NUM = 4;
5353
constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1);
54-
constexpr static int32_t EXPERT_NORMAL_NUM = 256;
55-
constexpr static int32_t BATCH_ROUND = 16;
54+
constexpr static int32_t BATCH_ROUND = 32;
5655

5756
public:
5857
__aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag)
@@ -72,7 +71,7 @@ class NotifyDispatch
7271
recvOffset_ = recvOffset;
7372
maxBs_ = maxBs;
7473
recvTokensPerExpert_ = recvTokensPerExpert;
75-
batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2;
74+
batchRounds = BATCH_ROUND;
7675
tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7776
sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
7877
sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
@@ -340,14 +339,12 @@ class NotifyDispatch
340339
uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup;
341340
uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup;
342341
uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t);
343-
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
344-
uint32_t strideElem = alignedDataLen / sizeof(int32_t); // 目标地址也改变,使用对齐后的地址
345342
DataCopyExtParams recvDataParams = {1U, static_cast<uint32_t>(singleRankBatchDataLen), 0, 0, 0};
346343
DataCopyPadExtParams<T> DataCopyPadExtParams{false, 0U, 0U, 0U};
347344

348345
for (uint32_t i = 0; i < rankSize; i++) {
349346
uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup;
350-
uint32_t dstOffset = i * strideElem;
347+
uint32_t dstOffset = i * singleRankBatchElemCount;
351348
// 搬运该Rank下的 currentBatchRounds 数据
352349
DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams);
353350
}
@@ -360,18 +357,14 @@ class NotifyDispatch
360357
Duplicate<T>(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V
361358

362359
SyncFunc<AscendC::HardEvent::V_S>();
363-
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
364-
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
365-
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
366360
uint32_t computeNum = currentBatchRounds * numLocalExperts;
367361
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
368362
uint32_t computeNumIn = r * numLocalExperts;
369363
uint32_t computeNumOut = r * numExperts;
370364
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
371365
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
372366
uint32_t index = expId * rankSize + srcRank;
373-
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
374-
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
367+
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
375368
recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx);
376369
}
377370
}
@@ -383,34 +376,56 @@ class NotifyDispatch
383376
sendOffsetTensor = sendOffsetBuf.Get<T>();
384377
Duplicate<T>(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t));
385378
SyncFunc<AscendC::HardEvent::V_S>();
386-
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
387-
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
388-
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
389379
uint32_t computeNum = currentBatchRounds * numLocalExperts;
390380
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
391381
uint32_t computeNumIn = r * numLocalExperts;
392382
uint32_t computeNumOut = r * numExperts;
393383
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
394384
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
395385
uint32_t index = expId * rankSize + srcRank;
396-
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
397-
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
386+
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
398387
sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1);
399388
}
400389
}
401390
}
402391
}
403392

393+
__aicore__ inline void ReorderSendTokensPerRankOutput()
394+
{
395+
pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
396+
pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen);
397+
sendTokensPerRankTensor = sendTokensPerRankBuf.Get<int32_t>();
398+
seenRoundTensor = seenRoundBuf.Get<int32_t>();
399+
Duplicate<int32_t>(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
400+
SyncFunc<AscendC::HardEvent::V_S>();
401+
SyncFunc<AscendC::HardEvent::MTE2_S>();
402+
for (uint32_t r = 0; r < round; ++r) {
403+
Duplicate<int32_t>(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
404+
SyncFunc<AscendC::HardEvent::V_S>();
405+
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
406+
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
407+
uint32_t index = expId * rankSize + srcRank;
408+
uint32_t pair_idx =
409+
sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId);
410+
if (!seenRoundTensor(srcRank)) {
411+
sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
412+
seenRoundTensor(srcRank) = 1;
413+
}
414+
}
415+
}
416+
SyncFunc<AscendC::HardEvent::S_V>();
417+
}
418+
}
419+
404420
__aicore__ inline void BuildTotalRecvTokens()
405421
{
406422
if (blockIdx != TOTAL_CNT_CORE) {
407423
return;
408424
}
409425
int32_t sumVal = 0;
410-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
411-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
412-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
413-
recvDataAlignLen = rankSize * singleRankAlignLen;
426+
427+
recvDataAlignLen =
428+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
414429
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
415430
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
416431
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -452,10 +467,8 @@ class NotifyDispatch
452467
if (blockIdx != RECV_COUNT_CORE) {
453468
return;
454469
}
455-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
456-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
457-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
458-
recvDataAlignLen = rankSize * singleRankAlignLen;
470+
recvDataAlignLen =
471+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
459472
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
460473
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
461474
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -492,10 +505,8 @@ class NotifyDispatch
492505
if (blockIdx != RECV_OFFSET_CORE) {
493506
return;
494507
}
495-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
496-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
497-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
498-
recvDataAlignLen = rankSize * singleRankAlignLen;
508+
recvDataAlignLen =
509+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
499510
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
500511
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
501512
pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen);
@@ -524,10 +535,8 @@ class NotifyDispatch
524535
if (blockIdx != MAX_BS_CORE) {
525536
return;
526537
}
527-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
528-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
529-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
530-
recvDataAlignLen = rankSize * singleRankAlignLen;
538+
recvDataAlignLen =
539+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
531540
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
532541

533542
pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
@@ -540,19 +549,16 @@ class NotifyDispatch
540549
SyncFunc<AscendC::HardEvent::MTE2_S>();
541550
for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) {
542551
uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds;
543-
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
544-
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
545-
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
552+
546553
ReorderOutput(rStart, currentBatchRounds);
547554
SyncFunc<AscendC::HardEvent::MTE2_S>();
548555
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
549-
uint32_t offsetInRound = r * numLocalExperts;
550556
Duplicate<int32_t>(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
551557
SyncFunc<AscendC::HardEvent::V_S>();
552558
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
553559
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
554-
uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId);
555-
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
560+
uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds +
561+
r * numLocalExperts + expId);
556562
if (!seenRoundTensor(srcRank)) {
557563
sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
558564
seenRoundTensor(srcRank) = 1;
@@ -579,10 +585,8 @@ class NotifyDispatch
579585
if (blockIdx != RECV_TOKEN_PER_EXP_CORE) {
580586
return;
581587
}
582-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
583-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
584-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
585-
recvDataAlignLen = rankSize * singleRankAlignLen;
588+
recvDataAlignLen =
589+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
586590
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
587591
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
588592
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -626,10 +630,8 @@ class NotifyDispatch
626630
return;
627631
}
628632

629-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
630-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
631-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
632-
recvDataAlignLen = rankSize * singleRankAlignLen;
633+
recvDataAlignLen =
634+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
633635
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
634636
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
635637
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -674,10 +676,8 @@ class NotifyDispatch
674676
if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) {
675677
return;
676678
}
677-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
678-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
679-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
680-
recvDataAlignLen = rankSize * singleRankAlignLen;
679+
recvDataAlignLen =
680+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
681681
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
682682
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
683683
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
@@ -726,10 +726,8 @@ class NotifyDispatch
726726
if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) {
727727
return;
728728
}
729-
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
730-
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
731-
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
732-
recvDataAlignLen = rankSize * singleRankAlignLen;
729+
recvDataAlignLen =
730+
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
733731
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
734732
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
735733
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);

csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ class CamMoeCombineNormal : public OpDef
3030
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
3131
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
3232
.AutoContiguous();
33-
this->Input("token_idx")
33+
this->Input("topk_idx")
3434
.ParamType(REQUIRED)
3535
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
3636
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})

0 commit comments

Comments (0)