HoistIfThenElse is a pass currently not enabled in TVM. I tried to enable it in #5553, but there are too many bugs in this pass. Let's fix them first.
BUG 1: HoistIfThenElse transforms
for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
into the following, where the condition on n.inner ends up outside the loop that defines n.inner:
if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
Possible cause:
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295
It only checks whether if_stmt has a preferred position, but that position is not guaranteed to be the current position. Changing the check to
if (if_position_map.count(if_stmt.get()) &&
if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {
may solve the problem.
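The proposed condition can be sketched in isolation. The snippet below is a simplified model, not the pass's code: if statements and loops are identified by plain strings instead of TVM nodes, and `IfPositionMap` and `ShouldHoistHere` are hypothetical names introduced only for illustration.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical model: key = identifier of an if statement,
// value = loop variable of the For node recorded as its preferred position.
using IfPositionMap = std::unordered_map<std::string, std::string>;

// Hoist only when a position is recorded AND that position is the loop
// currently being processed (the second half is the proposed extra check).
bool ShouldHoistHere(const IfPositionMap& if_position_map,
                     const std::string& if_key,
                     const std::string& top_for_var) {
  auto it = if_position_map.find(if_key);
  return it != if_position_map.end() && it->second == top_for_var;
}
```

With only the first half of the condition, an if statement whose preferred position is some other loop would also be hoisted here. In the real pass, the result of `as<ForNode>()` would presumably also need a null check if a recorded position can be something other than a For node.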
BUG 2: src/tir/transforms/split_host_device.cc expects the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" and the "else" branches, we have to rename the loop variables in the "else" branches so that they differ from those in the "then" branches. I have already written some code for this; see #5553.
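As a rough illustration of the renaming step, here is a minimal sketch that hands out a distinct name each time a loop is copied. It is not the code from #5553: `FreshName` is a hypothetical helper, and it works on names only, whereas a TVM variable is a distinct object, so the real fix would also have to substitute the new variable throughout the copied loop body.

```cpp
#include <string>
#include <unordered_set>

// Hypothetical helper: returns a name derived from `base` that is not yet in
// `used`, and records the returned name in `used`.
std::string FreshName(const std::string& base,
                      std::unordered_set<std::string>* used) {
  std::string candidate = base;
  int suffix = 0;
  // Append an increasing numeric suffix until the name is unused.
  while (used->count(candidate) != 0) {
    candidate = base + "_" + std::to_string(++suffix);
  }
  used->insert(candidate);
  return candidate;
}
```

Calling `FreshName("k", &used)` once for the loop in the "then" branch and once for its copy in the "else" branch gives the two loops different variable names, so each variable is defined only once.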
BUG 3: IfThenElse nodes containing thread indices should not be hoisted over the definition of those indices. This can happen when the Attr node for thread_extent is scheduled into the body of a For node by a compute_at command. I have already written some code for this; see #5553.
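The constraint can be phrased as one invariant: a condition must not be moved above the definition of any variable it reads, whether that definition is a For loop or a thread_extent attribute. The sketch below checks this invariant on variable names. `CanHoistAbove` is a hypothetical helper for illustration, not a function in TVM and not the code from #5553.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical check. `cond_vars` holds the variables the condition reads;
// `crossed_defs` holds the variables defined by the nodes (For loops or
// thread_extent attributes) that the condition would be moved above.
bool CanHoistAbove(const std::unordered_set<std::string>& cond_vars,
                   const std::vector<std::string>& crossed_defs) {
  for (const std::string& var : crossed_defs) {
    // The condition would end up outside the scope of `var`.
    if (cond_vars.count(var) != 0) return false;
  }
  return true;
}
```

For the BUG 1 example, a condition reading threadIdx.y and n.inner may not cross the n.inner loop, while a condition reading only threadIdx.y may cross both inner loops but not the node that defines threadIdx.y.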
BUG 4:
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371
Look at this line. if_stmt may already have been updated by the time this line runs. Consider the example below.
for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
After hoisting j >= 3, it becomes
for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
Now, when we are hoisting i >= 3, we need to compare and remove
if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}
But j >= 3 is already gone, so RemoveIf fails. We have to track updates to IfThenElse nodes, just as we already do for For nodes.
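One way to picture this tracking is a table that maps each statement to whatever it was last rewritten into, consulted before any comparison. The sketch below is only a model under that assumption: `UpdateTracker` is a hypothetical class, and statements are represented as strings rather than TVM nodes.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical record of rewrites: maps a statement to the statement it was
// most recently replaced with.
class UpdateTracker {
 public:
  void Record(const std::string& before, const std::string& after) {
    latest_[before] = after;
  }

  // Follows recorded rewrites until reaching a statement with no newer
  // version. The step count is bounded by the number of records so that an
  // accidental cycle cannot loop forever.
  std::string Resolve(std::string stmt) const {
    for (std::size_t step = 0, n = latest_.size(); step < n; ++step) {
      auto it = latest_.find(stmt);
      if (it == latest_.end()) break;
      stmt = it->second;
    }
    return stmt;
  }

 private:
  std::unordered_map<std::string, std::string> latest_;
};
```

In the example above, hoisting j >= 3 would record that the nested if on i >= 3 and j >= 3 has become a single if on i >= 3. When i >= 3 is hoisted afterwards, resolving the stale statement first yields the form that actually exists in the tree, so the removal step has something to match.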
BUG 5: This one is in the tests.
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175
Why do we expect a ('For', 'j') inside itself? As a potential problem, maybe we should change the variable names so that there are not two i's and two j's.
These are all the bugs I found.
Besides, I suggest changing every for (size_t i = 0; i < xxx.size(); i++) into for (size_t i = 0, n = xxx.size(); i < n; i++), since the C++ compiler cannot always prove that xxx.size() is loop-invariant.
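For reference, a minimal self-contained version of the suggested loop shape is shown below. `SumCachedSize` is a made-up function used only to demonstrate the pattern.

```cpp
#include <cstddef>
#include <vector>

// Sums the elements while reading xs.size() exactly once, in the loop
// initializer, instead of on every iteration.
int SumCachedSize(const std::vector<int>& xs) {
  int total = 0;
  for (std::size_t i = 0, n = xs.size(); i < n; i++) {
    total += xs[i];
  }
  return total;
}
```

This form is only equivalent to the original when the loop body does not change the size of the container. A loop that pushes to or erases from xxx while iterating must keep re-reading size().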
@kevinthesun Maybe you can have a look.