Skip to content
This repository was archived by the owner on Nov 25, 2022. It is now read-only.

Commit bf64614

Browse files
vinx13 authored and liuxinwei committed
[MetaSchedule] Fix thread bindings of MultiLevelTilingTensorCore (apache#13243)
1 parent f5c2f46 commit bf64614

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -556,6 +556,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
556 556       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
557 557       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
558 558       bool use_software_pipeline) {
    559 +   if (tile_binds.defined()) {
    560 +     for (const String& tile_bind : tile_binds.value()) {
    561 +       CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core.";
    562 +     }
    563 +   }
559 564     auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
560 565         structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
561 566

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
139 139     Array<ScheduleRule> results{ScheduleRule::MultiLevelTilingTensorCore(
140 140         /*intrin_groups=*/intrin_groups,
141 141         /*structure=*/"SSSRRSRS",
142     -       /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
    142 +       /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
143 143         /*max_innermost_factor=*/Integer(4),
144 144         /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
145 145         /*reuse_read=*/

0 commit comments

Comments
 (0)