Skip to content

Commit 57ad5c6

Browse files
authored
Fix the A2 notify_dispatch operator hang on CANN 8.3 (#245)
1 parent 3f61eab commit 57ad5c6

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

csrc/deepep/ops2/op_host/notify_dispatch_a2.cpp

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -92,25 +92,20 @@ class NotifyDispatchA2 : public OpDef
9292
this->Attr("local_rank_size").Int();
9393
this->Attr("local_rank_id").Int();
9494

95-
OpAICoreConfig aicore_config_base;
96-
aicore_config_base.DynamicCompileStaticFlag(true)
95+
OpAICoreConfig aicore_config_A2;
96+
aicore_config_A2.DynamicCompileStaticFlag(true)
9797
.DynamicFormatFlag(true)
9898
.DynamicRankSupportFlag(true)
9999
.DynamicShapeSupportFlag(true)
100100
.NeedCheckSupportFlag(false)
101101
.PrecisionReduceFlag(true)
102102
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
103+
.ExtendCfgInfo("prebuildPattern.value", "Opaque")
104+
.ExtendCfgInfo("jitCompile.flag", "static_false")
103105
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
104106

105-
OpAICoreConfig aicore_config_A2 = aicore_config_base;
106-
aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false");
107-
108-
OpAICoreConfig aicore_config = aicore_config_base;
109-
aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true");
110-
111-
this->AICore().AddConfig("ascend910_93", aicore_config);
112107
this->AICore().AddConfig("ascend910b", aicore_config_A2);
113-
this->MC2().HcclGroup("comm_group");
108+
this->MC2().HcclGroup({"comm_group"});
114109
}
115110
};
116111

csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -186,8 +186,8 @@ static void SetHcommCfg(const gert::TilingContext *context, NotifyDispatchA2Tili
186186
{
187187
const char *nodeName = context->GetNodeName();
188188
OP_LOGD(nodeName, "NotifyDispatchA2 commGroup = %s", commGroup.c_str());
189-
uint32_t opType1 = OP_TYPE_ALL_TO_ALL;
190-
std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise";
189+
uint32_t opType1 = 18; // batch write=18,
190+
std::string algConfigAllToAllStr = "BatchWrite=level1:hierarchy";
191191

192192
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(commGroup, opType1, algConfigAllToAllStr);
193193
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
@@ -442,6 +442,7 @@ static ge::graphStatus NotifyDispatchA2TilingFuncImpl(gert::TilingContext *conte
442442
if (socVersion == "Ascend910B") {
443443
tilingKey = tilingKey + TILING_KEY_A2_TYPE;
444444
}
445+
OP_LOGD(nodeName, "tilingKey is %lu", tilingKey);
445446
context->SetTilingKey(tilingKey);
446447

447448
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());

0 commit comments

Comments
 (0)