From c2da40d2df663bea7c478ff085b38a46173608d9 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 4 Apr 2024 20:53:34 +0800 Subject: [PATCH] [DLight] Fix a corner case for reduction rule The current rule fails when the output has only one element, because `preserve_unit_loops` is not set in the epilogue's `reverse_compute_at` call. This PR fixes that and adds a test case. --- python/tvm/dlight/gpu/reduction.py | 2 +- tests/python/dlight/test_gpu_reduction.py | 93 +++++++++++++++++++---- 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index 651e09dc5232..4cc142ab1614 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -217,7 +217,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments # Schedule epilogue if epilogue_info is not None: epilogue = epilogue_info.block_rv - sch.reverse_compute_at(epilogue, bx) + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) if is_broadcast_epilogue(sch, block, epilogue): sch.set_scope(block, 0, "shared") _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py index def124a9b29a..1ce57eb53d22 100644 --- a/tests/python/dlight/test_gpu_reduction.py +++ b/tests/python/dlight/test_gpu_reduction.py @@ -377,11 +377,12 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T with T.init(): C_local[0, 0, v0] = T.float16(0) C_local[0, 0, v0] = C_local[0, 0, v0] + C_rf_local[vax1_0_fused_1, 0, 0, v0] - with T.block("sigmoid"): - v0 = T.axis.spatial(4096, ax0_fused) - T.reads(C_local[0, 0, v0]) - T.writes(D[0, 0, v0]) - D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0]) + for ax0 in range(1): + with T.block("sigmoid"): + v0 = T.axis.spatial(4096, ax0_fused + ax0) + T.reads(C_local[0, 0, v0]) + T.writes(D[0, 0, v0]) + D[0, 0, v0] = T.sigmoid(C_local[0, 0, v0]) # fmt: on @@ -465,11 +466,12 @@ def 
func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T with T.init(): C_fp32_local[0, 0, v0] = T.float32(0) C_fp32_local[0, 0, v0] = C_fp32_local[0, 0, v0] + C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] - with T.block("cast"): - v0 = T.axis.spatial(4096, ax0_fused) - T.reads(C_fp32_local[0, 0, v0]) - T.writes(C[0, 0, v0]) - C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0]) + for ax0 in range(1): + with T.block("cast"): + v0 = T.axis.spatial(4096, ax0_fused + ax0) + T.reads(C_fp32_local[0, 0, v0]) + T.writes(C[0, 0, v0]) + C[0, 0, v0] = T.Cast("float16", C_fp32_local[0, 0, v0]) # fmt: on @@ -760,11 +762,12 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): with T.init(): temp_local_local[v0] = T.float32(0) temp_local_local[v0] = temp_local_local[v0] + temp_local_rf_local[vax1_fused_1, v0] - with T.block("add"): - v0 = T.axis.spatial(256, ax0_fused) - T.reads(temp_local_local[v0]) - T.writes(B[v0]) - B[v0] = temp_local_local[v0] + T.float32(1) + for ax0 in range(1): + with T.block("add"): + v0 = T.axis.spatial(256, ax0_fused + ax0) + T.reads(temp_local_local[v0]) + T.writes(B[v0]) + B[v0] = temp_local_local[v0] + T.float32(1) # fmt: on target = Target("nvidia/geforce-rtx-3090-ti") @@ -1089,5 +1092,65 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), " assert_structural_equal(mod, Expected) +def test_gemv_output_one_element(): + # fmt: off + @I.ir_module + class Before: + @T.prim_func(private=True) + def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)), "float16") + for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + with T.init(): + 
NT_matmul_intermediate[v_i0, v_i1] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + A[v_i0, v_k] * weight[v_i1, v_k] + for i0, i1 in T.grid(T.int64(1), T.int64(1)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + out[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, v_i1]) + + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + NT_matmul_intermediate_shared = T.alloc_buffer((T.int64(1), T.int64(1)), "float16", scope="shared") + NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(1024), T.int64(1), T.int64(1)), "float16", scope="local") + for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): + for ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + with T.block("NT_matmul_rf_init"): + vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = T.float16(0) + for ax1_fused_0, u in T.grid(T.int64(2), 1): + with T.block("NT_matmul_rf_update"): + vax1_fused_1 = T.axis.spatial(T.int64(1024), ax1_fused_1) + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + vax1_fused_0 = T.axis.reduce(T.int64(2), ax1_fused_0) + NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] = NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] + A[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] * weight[T.int64(0), vax1_fused_0 * T.int64(1024) + vax1_fused_1] + for ax1_fused in range(T.int64(1)): + for ax0 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_fused_1 = 
T.axis.reduce(T.int64(1024), ax0) + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + with T.init(): + NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] = T.float16(0) + NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] = NT_matmul_intermediate_shared[T.int64(0), T.int64(0)] + NT_matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), T.int64(0)] + for ax0_fused_0 in range(T.int64(1)): + for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("compute"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + T.where(ax0_fused_0 * T.int64(1024) + ax0_fused_1 < T.int64(1)) + out[T.int64(0), T.int64(0)] = T.sigmoid(NT_matmul_intermediate_shared[T.int64(0), T.int64(0)]) + # fmt: on + + with Target("nvidia/geforce-rtx-3090-ti"): + mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before) # pylint: disable=not-callable + assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main()