From 92be12a1c9a6a0b3379f3ff1fa88fa20e513ea1b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 6 Jan 2024 21:29:21 -0800 Subject: [PATCH] [Bugfix] Disable SingleEnvThreadVerifier During TensorIR scheduling, the `IterVar`s that represent environment threads may duplicate, i.e. it is legal to have two env threads with the same name tag, which may fail the `SingleEnvThreadVerifier` check during schedule creation. This PR disables this check in this case. In the future, it may be worthwhile to bring it back against post-scheduling TIR. --- src/tir/analysis/verify_well_formed.cc | 2 -- .../tir-schedule/test_tir_schedule_error.py | 26 +++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 58eadb20fa01..943a11971115 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -347,7 +347,6 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { } if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false; - if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false; // TODO(Siyuan): add more checks here. return true; @@ -364,7 +363,6 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { } if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false; - if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false; return true; } diff --git a/tests/python/tir-schedule/test_tir_schedule_error.py b/tests/python/tir-schedule/test_tir_schedule_error.py index 99de5305fdd5..755e822a2c01 100644 --- a/tests/python/tir-schedule/test_tir_schedule_error.py +++ b/tests/python/tir-schedule/test_tir_schedule_error.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys - import pytest + import tvm import tvm.testing from tvm import tir @@ -41,6 +40,25 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +@T.prim_func +def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32): + T.func_attr({"tir.noalias": T.bool(True)}) + A = T.match_buffer(var_A, (1, seq_len * 8), "int32") + B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8) + with T.block("exclusive_scan"): + T.reads() + T.writes() + s8: T.int32 = seq_len * 8 + if s8 == 0: + blockIdx_x = T.launch_thread("blockIdx.x", 1) + else: + with T.launch_thread("threadIdx.x", 1024) as threadIdx_x: + blockIdx_x = T.launch_thread("blockIdx.x", T.ceildiv(s8, 1024)) + i: T.int32 = blockIdx_x * 1024 + threadIdx_x + if i < s8: + B[i // s8, i % s8] = A[i // s8, i % s8] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -74,5 +92,9 @@ def test_tir_schedule_attribute_error(): sch.non_existent_field() +def test_tir_schedule_two_kernels(): + tir.Schedule(two_kernels) + + if __name__ == "__main__": tvm.testing.main()