diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index abc288f0eb24..1b2e8e9db04a 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -429,6 +429,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // Fix all local allocations as all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
     for (Buffer buf : new_alloc_bufs) {
+      body = DeclBuffer(buf, body);
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
     }
 
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index f797d35d47ca..d8c9568da90e 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -29,6 +29,7 @@ class BaseFailure(BaseCompare):
 
 
 class TestBasic(BaseCompare):
+    @T.prim_func(private=True)
     def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         A_flat = T.Buffer(4096, data=A.data)
@@ -54,6 +55,7 @@ def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
             if threadIdx_x == 0:
                 B[i] = reduce[0]
 
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         A_flat = T.Buffer(4096, data=A.data)
@@ -70,10 +72,10 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
                 T.reinterpret("handle", T.uint64(0)),
             ):
                 mask_data = T.allocate([1], "uint32", "local")
-                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+                mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local")
 
                 t0_data = T.allocate([1], "float32", "local")
-                t0 = T.Buffer(1, data=t0_data, scope="local")
+                t0 = T.decl_buffer(1, data=t0_data, scope="local")
 
                 reduce[0] = A_flat[0]
                 mask[0] = T.tvm_warp_activemask()
@@ -94,9 +96,10 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
 
 
 class TestBasicWithDeclBuffer(BaseCompare):
+    @T.prim_func(private=True)
     def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
-        A_flat = T.Buffer(4096, data=A.data)
+        A_flat = T.decl_buffer(4096, data=A.data)
 
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 32)
@@ -118,9 +121,10 @@ def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
             if threadIdx_x == 0:
                 B[i] = reduce[0]
 
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
-        A_flat = T.Buffer(4096, data=A.data)
+        A_flat = T.decl_buffer(4096, data=A.data)
 
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 32)
@@ -133,10 +137,10 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
                 T.reinterpret("handle", T.uint64(0)),
             ):
                 mask_data = T.allocate([1], "uint32", "local")
-                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+                mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local")
 
                 t0_data = T.allocate([1], "float32", "local")
-                t0 = T.Buffer(1, data=t0_data, scope="local")
+                t0 = T.decl_buffer(1, data=t0_data, scope="local")
 
                 reduce[0] = A_flat[0]
                 mask[0] = T.tvm_warp_activemask()
@@ -157,18 +161,19 @@ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
 
 
 class TestReduceSummation(BaseCompare):
+    @T.prim_func(private=True)
     def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
-        A_flat = T.Buffer((16384,), data=A.data)
+        A_flat = T.decl_buffer(16384, data=A.data)
 
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 32)
             normal_reduce_data = T.allocate([1], "float32", "local")
-            normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+            normal_reduce = T.decl_buffer(1, data=normal_reduce_data, scope="local")
             reduce_data = T.allocate([1], "float32", "local")
-            reduce = T.Buffer(1, data=reduce_data, scope="local")
+            reduce = T.decl_buffer(1, data=reduce_data, scope="local")
 
             normal_reduce[0] = T.float32(0)
@@ -190,18 +195,19 @@ def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
             if threadIdx_x == 0:
                 B[i] = reduce[0]
 
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
-        A_flat = T.Buffer(16384, data=A.data)
+        A_flat = T.decl_buffer(16384, data=A.data)
 
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 32)
             normal_reduce_data = T.allocate([1], "float32", "local")
-            normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+            normal_reduce = T.decl_buffer(1, data=normal_reduce_data, scope="local")
             reduce_data = T.allocate([1], "float32", "local")
-            reduce = T.Buffer(1, data=reduce_data, scope="local")
+            reduce = T.decl_buffer(1, data=reduce_data, scope="local")
 
             normal_reduce[0] = T.float32(0)
             for ko in range(4):
@@ -212,10 +218,10 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
                 T.reinterpret("handle", T.uint64(0)),
             ):
                 mask_data = T.allocate([1], "uint32", "local")
-                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+                mask = T.decl_buffer(1, "uint32", data=mask_data, scope="local")
 
                 t0_data = T.allocate([1], "float32", "local")
-                t0 = T.Buffer(1, data=t0_data, scope="local")
+                t0 = T.decl_buffer(1, data=t0_data, scope="local")
 
                 reduce[0] = normal_reduce[0]
                 mask[0] = T.tvm_warp_activemask()
@@ -236,7 +242,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
 
 
 class TestMultiGroupReduction(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 32)
@@ -260,7 +266,7 @@ def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
             B_1 = T.Buffer((32,), data=B.data)
            B_1[threadIdx_y] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 32)
@@ -272,33 +278,31 @@ def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            mask = T.allocate([1], "uint32", "local")
-            t0 = T.allocate([1], "float32", "local")
+            mask = T.decl_buffer([1], "uint32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
             A_1 = T.Buffer((1024,), data=A.data)
             red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]
-            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
-            mask_1[0] = T.tvm_warp_activemask()
-
-            t0_1 = T.Buffer((1,), data=t0, scope="local")
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
+            mask[0] = T.tvm_warp_activemask()
+
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
         if threadIdx_x == 0:
             B_1 = T.Buffer((32,), data=B.data)
             B_1[threadIdx_y] = red_buf0_1[0]
 
 
 class TestMultiGroupMask1(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 32)
@@ -322,7 +326,7 @@ def before(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
             B_1 = T.Buffer((32,), data=B.data)
             B_1[threadIdx_y] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 32)
@@ -334,30 +338,28 @@ def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            mask = T.allocate([1], "uint32", "local")
-            t0 = T.allocate([1], "float32", "local")
+            mask = T.decl_buffer([1], "uint32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
             A_1 = T.Buffer((256,), data=A.data)
             red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x]
-            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
-            mask_1[0] = T.bitwise_and(
+            mask[0] = T.bitwise_and(
                 T.tvm_warp_activemask(),
                 T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)),
             )
-            t0_1 = T.Buffer((1,), data=t0, scope="local")
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
-            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
-            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 8 * threadIdx_y, 32, 32)
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask[0], red_buf0_1[0], 8 * threadIdx_y, 32, 32)
         if threadIdx_x == 0:
             B_1 = T.Buffer((32,), data=B.data)
             B_1[threadIdx_y] = red_buf0_1[0]
 
 
 class TestMultiWarpReduce1(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         for i in range(128):
@@ -381,7 +383,7 @@ def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")):
                 B_1 = T.Buffer((128,), data=B.data)
                 B_1[i] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         for i in range(128):
@@ -394,45 +396,38 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
                 "reduce_scope",
                 T.reinterpret("handle", T.uint64(0)),
             ):
-                red_buf0 = T.allocate([1], "float32", "local")
-                mask = T.allocate([1], "uint32", "local")
-                t0 = T.allocate([1], "float32", "local")
-                red_buf0_1 = T.allocate([1], "float32", "local")
-                mask_1 = T.allocate([1], "uint32", "local")
-                t0_1 = T.allocate([1], "float32", "local")
-                red_buf_staging = T.allocate([4], "float32", "shared")
-                red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+                red_buf0 = T.decl_buffer([1], "float32", scope="local")
+                mask = T.decl_buffer([1], "uint32", scope="local")
+                t0 = T.decl_buffer([1], "float32", scope="local")
+                red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
+                mask_1 = T.decl_buffer([1], "uint32", scope="local")
+                t0_1 = T.decl_buffer([1], "float32", scope="local")
+                red_buf_staging = T.decl_buffer([4], "float32", scope="shared")
                 A_1 = T.Buffer((16384,), data=A.data)
-                red_buf0_2[0] = A_1[i * 128 + threadIdx_x]
-                mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
-                mask_2[0] = T.tvm_warp_activemask()
-                t0_2 = T.Buffer((1,), data=t0_1, scope="local")
-                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
-                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
-                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
-                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
-                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
-                red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-                red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, scope="shared")
+                red_buf0_1[0] = A_1[i * 128 + threadIdx_x]
+                mask_1[0] = T.tvm_warp_activemask()
+                t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+                t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+                t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+                t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+                t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                 if threadIdx_x % 32 == 0:
-                    red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
+                    red_buf_staging[threadIdx_x // 32] = red_buf0_1[0]
                 T.tvm_storage_sync("shared")
-                red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
                 if threadIdx_x < 4:
-                    red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
-                    mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
-                    mask_3[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15))
-                    t0_3 = T.Buffer((1,), data=t0, scope="local")
-                    t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
-                    red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                    t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
-                    red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                    red_buf0[0] = red_buf_staging[threadIdx_x]
+                    mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15))
+                    t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
+                    red_buf0[0] = red_buf0[0] + t0[0]
+                    t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
+                    red_buf0[0] = red_buf0[0] + t0[0]
                 if threadIdx_x == 0:
-                    red_result_1[0] = red_buf0_3[0]
+                    red_result_1[0] = red_buf0[0]
                 T.tvm_storage_sync("shared")
             if threadIdx_x == 0:
                 B_1 = T.Buffer((128,), data=B.data)
@@ -440,7 +435,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
 
 
 class TestMultiWarpReduce2(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_x = T.launch_thread("threadIdx.x", 1024)
@@ -459,7 +454,7 @@ def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
             B_1 = T.Buffer((1,), data=B.data)
             B_1[0] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_x = T.launch_thread("threadIdx.x", 1024)
@@ -471,51 +466,44 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            red_buf0 = T.allocate([1], "float32", "local")
-            mask = T.allocate([1], "uint32", "local")
-            t0 = T.allocate([1], "float32", "local")
-            red_buf0_1 = T.allocate([1], "float32", "local")
-            mask_1 = T.allocate([1], "uint32", "local")
-            t0_1 = T.allocate([1], "float32", "local")
-            red_buf_staging = T.allocate([32], "float32", "shared")
-            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            red_buf0 = T.decl_buffer([1], "float32", scope="local")
+            mask = T.decl_buffer([1], "uint32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
+            red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
+            mask_1 = T.decl_buffer([1], "uint32", scope="local")
+            t0_1 = T.decl_buffer([1], "float32", scope="local")
+            red_buf_staging = T.decl_buffer([32], "float32", scope="shared")
             A_1 = T.Buffer((1024,), data=A.data)
-            red_buf0_2[0] = A_1[threadIdx_x]
-            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
-            mask_2[0] = T.tvm_warp_activemask()
-            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
+            red_buf0_1[0] = A_1[threadIdx_x]
+            mask_1[0] = T.tvm_warp_activemask()
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
             if threadIdx_x % 32 == 0:
-                red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
+                red_buf_staging[threadIdx_x // 32] = red_buf0_1[0]
             T.tvm_storage_sync("shared")
-            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 32:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
-                mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
-                mask_3[0] = T.tvm_warp_activemask()
-                t0_3 = T.Buffer((1,), data=t0, scope="local")
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 16, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                red_buf0[0] = red_buf_staging[threadIdx_x]
+                mask[0] = T.tvm_warp_activemask()
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 16, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
             if threadIdx_x == 0:
-                red_result_1[0] = red_buf0_3[0]
+                red_result_1[0] = red_buf0[0]
             T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
             B_1 = T.Buffer((1,), data=B.data)
@@ -523,7 +511,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
 
 
 class TestMultiGroupMultiWarpReduction(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 4)
@@ -547,7 +535,7 @@ def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
             B_1 = T.Buffer((4,), data=B.data)
             B_1[threadIdx_y] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 4)
@@ -560,47 +548,40 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            red_buf0 = T.allocate([1], "float32", "local")
-            mask = T.allocate([1], "uint32", "local")
-            t0 = T.allocate([1], "float32", "local")
-            red_buf0_1 = T.allocate([1], "float32", "local")
-            mask_1 = T.allocate([1], "uint32", "local")
-            t0_1 = T.allocate([1], "float32", "local")
-            red_buf_staging = T.allocate([16], "float32", "shared")
-            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            red_buf0 = T.decl_buffer([1], "float32", scope="local")
+            mask = T.decl_buffer([1], "uint32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
+            red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
+            mask_1 = T.decl_buffer([1], "uint32", scope="local")
+            t0_1 = T.decl_buffer([1], "float32", scope="local")
+            red_buf_staging = T.decl_buffer([16], "float32", scope="shared")
             A_1 = T.Buffer((512,), data=A.data)
-            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
-            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
-            mask_2[0] = T.tvm_warp_activemask()
-            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, scope="shared")
+            red_buf0_1[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            mask_1[0] = T.tvm_warp_activemask()
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
             if threadIdx_x % 32 == 0:
-                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+                red_buf_staging[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_1[0]
             T.tvm_storage_sync("shared")
-            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 4:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
-                mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
-                mask_3[0] = T.bitwise_and(
+                red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x]
+                mask[0] = T.bitwise_and(
                     T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
                 )
-                t0_3 = T.Buffer((1,), data=t0, scope="local")
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
             if threadIdx_x == 0:
-                red_result_1[threadIdx_y] = red_buf0_3[0]
+                red_result_1[threadIdx_y] = red_buf0[0]
             T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
            B_1 = T.Buffer((4,), data=B.data)
@@ -608,7 +589,7 @@ class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 2)
@@ -633,7 +614,7 @@ def before(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
             B_1 = T.Buffer((2,), data=B.data)
             B_1[threadIdx_y] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 2)
@@ -652,50 +633,43 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            red_buf0 = T.allocate([1], "float32", "local")
-            mask = T.allocate([1], "uint32", "local")
-            t0 = T.allocate([1], "float32", "local")
-            red_buf0_1 = T.allocate([1], "float32", "local")
-            mask_1 = T.allocate([1], "uint32", "local")
-            t0_1 = T.allocate([1], "float32", "local")
-            red_buf_staging = T.allocate([32], "float32", "shared")
-            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
-            red_buf0_2[0] = in_thread_B_1[0]
-            mask_2 = T.Buffer((1,), "uint32", data=mask_1, scope="local")
-            mask_2[0] = T.tvm_warp_activemask()
-            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 16, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 8, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 4, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 2, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
+            red_buf0 = T.decl_buffer([1], "float32", scope="local")
+            mask = T.decl_buffer([1], "uint32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
+            red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
+            mask_1 = T.decl_buffer([1], "uint32", scope="local")
+            t0_1 = T.decl_buffer([1], "float32", scope="local")
+            red_buf_staging = T.decl_buffer([32], "float32", scope="shared")
+            red_buf0_1[0] = in_thread_B_1[0]
+            mask_1[0] = T.tvm_warp_activemask()
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
             if threadIdx_x % 32 == 0:
-                red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
+                red_buf_staging[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_1[0]
             T.tvm_storage_sync("shared")
-            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 16:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x]
-                mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
-                mask_3[0] = T.bitwise_and(
+                red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x]
+                mask[0] = T.bitwise_and(
                     T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
                 )
-                t0_3 = T.Buffer((1,), data=t0, scope="local")
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 4, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
             if threadIdx_x == 0:
-                red_result_1[threadIdx_y] = red_buf0_3[0]
+                red_result_1[threadIdx_y] = red_buf0[0]
             T.tvm_storage_sync("shared")
         if threadIdx_x == 0:
             B_1 = T.Buffer((2,), data=B.data)
@@ -703,7 +677,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
 
 
 class TestMetalNoMask(BaseCompare):
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
         T.func_attr(
             {
@@ -740,7 +714,7 @@ def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float
             B_1 = T.Buffer((2,), data=B.data)
             B_1[threadIdx_y] = cross_thread_B_1[0]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
         T.func_attr(
            {
@@ -766,39 +740,34 @@ def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "flo
             "reduce_scope",
             T.reinterpret("handle", T.uint64(0)),
         ):
-            red_buf0 = T.allocate([1], "float32", "local")
-            t0 = T.allocate([1], "float32", "local")
-            red_buf0_1 = T.allocate([1], "float32", "local")
-            t0_1 = T.allocate([1], "float32", "local")
-            red_buf_staging = T.allocate([8], "float32", "shared")
-            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            red_buf0 = T.decl_buffer([1], "float32", scope="local")
+            t0 = T.decl_buffer([1], "float32", scope="local")
+            red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
+            t0_1 = T.decl_buffer([1], "float32", scope="local")
+            red_buf_staging = T.decl_buffer([8], "float32", scope="shared")
             A_1 = T.Buffer((256,), data=A.data)
-            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
-            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
-            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32)
-            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
-            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
+            red_buf0_1[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
             if threadIdx_x % 32 == 0:
-                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+                red_buf_staging[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_1[0]
             T.tvm_storage_sync("shared")
-            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
             if threadIdx_x < 4:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
-                t0_3 = T.Buffer((1,), data=t0, scope="local")
-                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
-                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32)
-                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x]
+                t0[0] = T.tvm_warp_shuffle_down(0, red_buf0[0], 2, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(0, red_buf0[0], 1, 32, 32)
+                red_buf0[0] = red_buf0[0] + t0[0]
             if threadIdx_x == 0:
-                red_result_1[threadIdx_y] = red_buf0_3[0]
+                red_result_1[threadIdx_y] = red_buf0[0]
             T.tvm_storage_sync("shared")
             if threadIdx_x == 0:
                 B_1 = T.Buffer((2,), data=B.data)