From 449b3d755c942afea97d80cfa4ce1c0c490a02b8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Apr 2024 12:57:27 -0700 Subject: [PATCH 1/2] [Dlight] Enhance vectorization for gpu matmul --- python/tvm/dlight/gpu/matmul.py | 7 ++- tests/python/dlight/test_gpu_matmul.py | 81 +++++++++++++------------- 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 0f224b89f9e4..9043e8679e6a 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -873,7 +873,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring x, [None, config.vthread_x, config.block_size_x, config.micro_size_x] ) ko, ki = sch.split(k, factors=[None, config.micro_size_k]) - sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) + reordered_loops = [by, bx, vy, vx, ty, tx, ko, ki] + ( + [yi, xi] if config.inner_x else [xi, yi] + ) + sch.reorder(*reordered_loops) by = sch.fuse(batch, by) sch.bind(bx, "blockIdx.x") sch.bind(by, "blockIdx.y") @@ -883,7 +886,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.bind(tx, "threadIdx.x") inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y if inner_loop % config.vector_size == 0: - _, v = sch.split(xi, [None, config.vector_size]) + _, v = sch.split(reordered_loops[-1], [None, config.vector_size]) sch.vectorize(v) if config.unroll > 0: diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 82f481da469d..a421d9e6c734 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -63,12 +63,12 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"): for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)): - for ax1_3_1_init in T.vectorized(T.int64(2)): + for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): + for ax2_3_1_init in T.vectorized(T.int64(2)): with T.block("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) T.reads() T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0) @@ -97,12 +97,12 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), T.writes(inp1_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1] - for ax3_1, ax2_3, ax1_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): - for ax1_3_1 in T.vectorized(T.int64(2)): + for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): + for ax2_3_1 in T.vectorized(T.int64(2)): with T.block("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) + v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3]) T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2]) @@ -117,7 +117,6 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), T.writes(matmul[T.int64(0), v1, v2]) if v1 < m: matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2] - # fmt: on @@ -151,12 +150,12 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(4, 2): - for ax1_3_1_init in T.vectorized(2): + for ax1_3_init, ax2_3_0_init in T.grid(4, 2): + for ax2_3_1_init in T.vectorized(2): with T.block("matmul_init"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init) - v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) + v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) + v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) T.reads() T.writes(matmul_reindex_pad_local[0, v1, v2]) matmul_reindex_pad_local[0, v1, v2] = T.float32(0) @@ -185,12 +184,12 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma T.writes(inp1_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1] - for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2): - for ax1_3_1 in T.vectorized(2): + for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): + for ax2_3_1 in T.vectorized(2): with T.block("matmul_update"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1) - v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) + v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) + v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1) T.reads(matmul_reindex_pad_local[0, v1, v2], inp0_reindex_pad_shared[0, v1, v3], inp1_reindex_shared[0, v2, v3]) T.writes(matmul_reindex_pad_local[0, v1, v2]) @@ -254,12 +253,12 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"): for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)): - for ax1_3_1_init in T.vectorized(T.int64(2)): + for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): + for ax2_3_1_init in T.vectorized(T.int64(2)): with T.block("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) + v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) T.reads() T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2]) var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0) @@ -288,12 +287,12 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax3_1, ax2_3, ax1_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): - for ax1_3_1 in T.vectorized(T.int64(2)): + for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): + for ax2_3_1 in T.vectorized(T.int64(2)): with T.block("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) + v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]) T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2]) @@ -417,12 +416,12 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"): for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)): - for ax1_3_1_init in T.vectorized(T.int64(2)): + for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): + for ax2_3_1_init in T.vectorized(T.int64(2)): with T.block("matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) + v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) T.reads() T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2]) var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0) @@ -451,12 +450,12 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu T.writes(p_output0_intermediate_1_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_1_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v2, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v2, v1 // T.int64(32)] - for ax3_1, ax2_3, ax1_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): - for ax1_3_1 in T.vectorized(T.int64(2)): + for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): + for ax2_3_1 in T.vectorized(T.int64(2)): with T.block("matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1) - v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) + v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1) T.reads(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv48_reindex_pad_shared[T.int64(0), v1, v3], p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3]) T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2]) @@ -546,12 +545,12 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"): for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)): - for ax1_3_1_init in T.vectorized(T.int64(2)): + for ax1_3_init, ax2_3_0_init in T.grid(T.int64(4), T.int64(2)): + for ax2_3_1_init in T.vectorized(T.int64(2)): with T.block("NT_matmul_init"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init) - v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) + v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init) + v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0_init * T.int64(2) + ax2_3_1_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float16(0) @@ -580,12 +579,12 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl T.writes(lv9_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv9_reindex_shared[v0, v1, v2] = lv9[v1, v2] - for ax3_1, ax2_3, ax1_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): - for ax1_3_1 in T.vectorized(T.int64(2)): + for ax3_1, ax1_3, ax2_3_0 in T.grid(T.int64(16), T.int64(4), T.int64(2)): + for ax2_3_1 in T.vectorized(T.int64(2)): with T.block("NT_matmul_update"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1) - v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) + v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3) + v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_0 * T.int64(2) + ax2_3_1) v3 = T.axis.reduce(T.int64(2048), ax3_0 * T.int64(16) + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv26_reindex_pad_shared[T.int64(0), v1, v3], lv9_reindex_shared[T.int64(0), v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2]) From f52ffa7e81e1ca9fda094bdae3879e7390af703a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 17 Apr 2024 10:21:22 -0700 Subject: [PATCH 2/2] fix --- .../python/dlight/test_gpu_matmul_tensorize.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py index 72ffb307194a..095447766e28 100644 --- a/tests/python/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/dlight/test_gpu_matmul_tensorize.py @@ -190,7 +190,7 @@ def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.ha @T.prim_func def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) m = T.int32() X = T.match_buffer(var_X, (m, 256), "float16") compute = T.match_buffer(var_compute, (m, 15)) @@ -204,12 +204,12 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(4, 2): - for ax1_3_1_init in T.vectorized(2): + for ax1_3_init, ax2_3_0_init in T.grid(4, 2): + for ax2_3_1_init in T.vectorized(2): with T.block("compute_init"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init) - v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) + v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) + v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init) T.reads() T.writes(compute_reindex_pad_local[0, v1, v2]) compute_reindex_pad_local[0, v1, v2] = T.float32(0) @@ -238,12 +238,12 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T. T.writes(W_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], T.float16(0)) - for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2): - for ax1_3_1 in T.vectorized(2): + for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2): + for ax2_3_1 in T.vectorized(2): with T.block("compute_update"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1) - v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3) + v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) + v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1) v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1) T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3]) T.writes(compute_reindex_pad_local[0, v1, v2])