Skip to content

Commit 7016066

Browse files
[Cherry-Pick][Fix][TIR] UnifyThreadBinding creating unit loop with annotation (apache/tvm#14588) (#187)
This PR fixes a behavior of the UnifyThreadBinding pass which (at one place) assumes a return value is always a ForNode, which is not right. To be more specific, when a thread-binding loop has an annotation, the current behavior is assuming that the post-recursive-mutation value is also a ForNode, and apply the previous annotation directly to the new loop. However, the post-recursive-mutation value is also possibly not a ForNode. In this case, the current behavior is incorrect. This PR creates a new unit-length loop in this case to preserve the annotation. Thanks Bohan for catching this issue. Co-authored-by: Bohan Hou <spectrometerh@gmail.com>
1 parent 65efae0 commit 7016066

2 files changed

Lines changed: 39 additions & 3 deletions

File tree

src/tir/transforms/unify_thread_binding.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,20 @@ class ThreadBindingUnifier : public StmtExprMutator {
6464
if (annotations.empty()) {
6565
return stmt;
6666
}
67-
For new_loop = Downcast<For>(stmt);
68-
new_loop.CopyOnWrite()->annotations = std::move(annotations);
69-
return std::move(new_loop);
67+
if (const auto* loop = stmt.as<ForNode>()) {
68+
For new_loop = GetRef<For>(loop);
69+
new_loop.CopyOnWrite()->annotations = std::move(annotations);
70+
return std::move(new_loop);
71+
} else {
72+
// Create a new unit loop with the annotation.
73+
DataType dtype = op->loop_var->dtype;
74+
return For(/*loop_var=*/Var("var", dtype), //
75+
/*min=*/IntImm(dtype, 0), //
76+
/*extent=*/IntImm(dtype, 1), //
77+
/*kind=*/ForKind::kSerial, stmt, //
78+
/*thread_binding=*/NullOpt, //
79+
/*annotation=*/std::move(annotations));
80+
}
7081
}
7182

7283
template <typename Node>

tests/python/unittest/test_tir_transform_unify_thread_binding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,31 @@ def test_implicit_block():
286286
_check(element_wise_implicit_block, unified_element_wise_implicit_block)
287287

288288

289+
def test_inner_binding_with_annotation():
290+
@T.prim_func
291+
def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")):
292+
for bx in T.thread_binding(32, "blockIdx.x"):
293+
for tx in T.thread_binding(2, "threadIdx.x", annotations={"my_annotation": 1}):
294+
with T.block("block"):
295+
v = T.axis.spatial(64, bx * 2 + tx)
296+
B[v] = A[v]
297+
298+
@T.prim_func
299+
def unified_inner_binding_with_annotation(
300+
A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")
301+
):
302+
for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
303+
for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
304+
for var in T.serial(1, annotations={"my_annotation": 1}):
305+
with T.block("block"):
306+
v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x)
307+
T.reads(A[v])
308+
T.writes(B[v])
309+
B[v] = A[v]
310+
311+
_check(inner_binding_with_annotation, unified_inner_binding_with_annotation)
312+
313+
289314
def test_lower_te():
290315
a = te.placeholder((32, 2, 2))
291316
b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)

0 commit comments

Comments
 (0)