diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index a75bfdcddcda..21de2d86070f 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -801,7 +801,7 @@ class PipelineRewriter : public StmtExprMutator { auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; - if (!analyzer_.CanProve(extent > 0)) { + if (analyzer_.CanProve(extent <= 0)) { return make_nop(); } bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 2a1ce2be28f5..a013cf0f65b8 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -153,6 +153,74 @@ def transformed_simple_compute( C[tx, 15] = B[1, tx, 0] + T.float32(1) +@T.prim_func +def dynamic_compute(a_handle: T.handle, c_handle: T.handle): + k = T.int32() + A = T.match_buffer(a_handle, (16, k), "float32") + C = T.match_buffer(c_handle, (16, k), "float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + k, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_dynamic_compute(a_handle: T.handle, c_handle: T.handle): + k = T.int32() + A = T.match_buffer(a_handle, (16, k), "float32") + C = T.match_buffer(c_handle, (16, k), "float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0 : T.max(1, k)]) + T.writes(C[tx, T.min(0, k - 1) : T.min(0, k - 1) + T.max(k, 1)]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(""): + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + with T.block(""): + T.where(0 < k) + T.reads(A[tx, 0]) + T.writes(B[0, tx, 0]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(""): + T.reads(A[tx, 1 : 1 + (k - 1)], B[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[tx, 0 : k - 1]) + for i in range(k - 1): + with T.block(""): + T.reads(A[tx, i + 1]) + T.writes(B[(i + 1) % 2, tx, 0]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(""): + T.reads(B[i % 2, tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(""): + T.reads(B[(k + 1) % 2, tx, 0]) + T.writes(C[tx, k - 1]) + with T.block(""): + T.where(1 <= k) + T.reads(B[(k + 1) % 2, tx, 0]) + T.writes(C[tx, k - 1]) + C[tx, k - 1] = B[(k + 1) % 2, tx, 0] + T.float32(1) + + @T.prim_func def simple_compute_with_other_annotation( A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32") @@ -1069,6 +1137,10 @@ def test_simple_compute_with_other_annotation(): _check(simple_compute_with_other_annotation, transformed_simple_compute_with_other_annotation) +def test_dynamic_compute(): + _check(dynamic_compute, transformed_dynamic_compute) + + def test_trivial_pipeline(): _check(trivial_pipeline, transformed_trivial_pipeline)