diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py
index 73b265d6ceb1..a19a292fa13e 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -18,7 +18,8 @@
 from .analysis import (
     BlockInfo,
     IterInfo,
-    collect_vars_used_in_access_region,
+    collect_block_iter_vars_used_in_access_region,
+    collect_vars_used_in_prim_expr,
     detect_dominant_read,
     is_broadcast_epilogue,
     normalize_prim_func,
diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/base/analysis.py
index 96bc6286388c..be260b894203 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -208,17 +208,27 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
     return sch.get_block(block.name_hint)
 
 
-def collect_vars_used_in_access_region(region: List[ir.Range]) -> Set[tir.Var]:
-    """Collect the variables used in the access region of a buffer region."""
-    tir_vars: Set[tir.Var] = set()
+def collect_block_iter_vars_used_in_access_region(
+    block: tir.Block, region: List[ir.Range]
+) -> Set[tir.Var]:
+    """Collect the block iter variables used in the access region of a buffer region."""
+    tir_vars = set()
+    for expr in region:
+        assert expr.extent == 1
+        tir_vars |= collect_vars_used_in_prim_expr(expr.min)
+    tir_vars &= set(iter_var.var for iter_var in block.iter_vars)
+    return tir_vars
+
+
+def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
+    """Collect the variables used in the PrimExpr."""
+    tir_vars = set()
 
     def _collect_tir_var(expr):
         if isinstance(expr, tir.Var):
             tir_vars.add(expr)
 
-    for expr in region:
-        assert expr.extent == 1
-        tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+    tir.stmt_functor.post_order_visit(expr, _collect_tir_var)
     return tir_vars
 
 
@@ -227,7 +237,7 @@ def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read = None
     num_read_iters = -1
     for buffer_region in block.reads:
-        tir_vars = collect_vars_used_in_access_region(buffer_region.region)
+        tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region)
         if num_read_iters < len(tir_vars):
             num_read_iters = len(tir_vars)
             dominant_read = buffer_region
@@ -247,7 +257,9 @@ def is_broadcast_epilogue(
     for buffer_region in sch.get(epilogue).reads:
         if buffer_region.buffer not in write_buffers:
             continue
-        tir_vars = collect_vars_used_in_access_region(buffer_region.region)
+        tir_vars = collect_block_iter_vars_used_in_access_region(
+            sch.get(epilogue), buffer_region.region
+        )
         if len(tir_vars) < len(epilogue_iters):
             return True
     return False
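The rename above splits one utility into two: collect_vars_used_in_prim_expr is a plain visitor over a single expression, while collect_block_iter_vars_used_in_access_region intersects the collected vars with the block's own iter vars, so values read from memory (e.g. an indptr[...] offset) no longer count as access iterators. A minimal sketch of how they compose (my example, not part of the patch; it only assumes the renamed exports above):

from tvm import tir
from tvm.script import tir as T
from tvm.dlight.base import (
    collect_block_iter_vars_used_in_access_region,
    collect_vars_used_in_prim_expr,
)


@T.prim_func
def row_sum(x: T.Buffer((8, 16), "float32"), y: T.Buffer((8,), "float32")):
    for i, j in T.grid(8, 16):
        with T.block("row_sum"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                y[vi] = T.float32(0)
            y[vi] = y[vi] + x[vi, vj]


sch = tir.Schedule(row_sum)
block = sch.get(sch.get_block("row_sum"))
x_read = next(r for r in block.reads if r.buffer.name == "x")
# The read x[vi, vj] touches both block iter vars, the write y[vi] only one:
assert len(collect_block_iter_vars_used_in_access_region(block, x_read.region)) == 2
assert len(collect_block_iter_vars_used_in_access_region(block, block.writes[0].region)) == 1
# The expression-level helper is the underlying visitor:
assert len(collect_vars_used_in_prim_expr(x_read.region[0].min)) == 1  # just vi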
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 27b155c6a754..d453b84bc055 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -24,7 +24,8 @@
 from ..base import (
     BlockInfo,
-    collect_vars_used_in_access_region,
+    collect_block_iter_vars_used_in_access_region,
+    collect_vars_used_in_prim_expr,
     detect_dominant_read,
     is_broadcast_epilogue,
     normalize_prim_func,
@@ -86,7 +87,10 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe
     conditions.append(len(block_stmt.reads) >= 2)
     conditions.append(len(block_stmt.writes) == 1)
     conditions.append(_get_reduction_expr(block_stmt) is not None)
-    conditions.append(len(collect_vars_used_in_access_region(block_stmt.writes[0].region)) > 0)
+    conditions.append(
+        len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region))
+        > 0
+    )
     if not all(conditions):
         return None
 
@@ -94,7 +98,8 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe
     ret = [
         read.buffer
         for read in block_stmt.reads
-        if len(collect_vars_used_in_access_region(read.region)) < iter_num
+        if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num
+        and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0
     ]
     return ret if 0 < len(ret) < len(block_stmt.reads) else None
 
@@ -109,12 +114,19 @@ def normalize(
         detect_dominant_read(block_stmt),
         input_iters={i.var: i.dom for i in block_stmt.iter_vars},
     )
-
-    buffers_use_vars = [collect_vars_used_in_access_region(buf.region) for buf in block_stmt.writes]
+    buffers_use_vars = [
+        collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+        for buf in block_stmt.writes
+    ]
     buffers_use_vars.extend(
-        [collect_vars_used_in_access_region(buf.region) for buf in block_stmt.reads]
+        [
+            collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+            for buf in block_stmt.reads
+        ]
     )
-    if access.base != 0:
+    if collect_vars_used_in_prim_expr(access.base) & set(
+        iter_var.var for iter_var in block_stmt.iter_vars
+    ):
         return None
     iter_to_info = {i.var: i for i in block_info.iters}
     batch_loops, s_loops, r_loops, c_loops = [], [], [], []
@@ -420,15 +432,15 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         elif target.kind.name == "metal":
             # Note that the following tile size is tuned on M2 Ultra for 7B
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
-            VEC_C = 4
+            VEC_C = 1
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 256
             if isinstance(len_S, int):
                 if len_S > len_R:
-                    TS, TR = 1, 64
+                    TS, TR = 4, 16
                 else:
-                    TS, TR = 1, 256
+                    TS, TR = 2, 64
         elif target.kind.name == "rocm":
             VEC_C = 4
             LOAD_V_SHARED = True
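The access.base hunk is the behavioral core of this file's change: the old guard rejected any non-zero base of the normalized dominant read, while the new one only rejects bases that involve the block's own iter vars, which is what lets indirect offsets like indptr[expert] through (see the blockized test below). Restated as a standalone predicate (a sketch of the same logic, not the dlight source itself):

from tvm import tir
from tvm.dlight.base import collect_vars_used_in_prim_expr


def base_uses_block_iters(access_base: tir.PrimExpr, block_stmt: tir.Block) -> bool:
    """True iff the normalized base expression mentions one of the block's iter vars."""
    return bool(
        collect_vars_used_in_prim_expr(access_base)
        & {iter_var.var for iter_var in block_stmt.iter_vars}
    )

The Metal numbers in the last hunk are a retune rather than a refactor: the spatial thread dimension TS goes from 1 to 4 (or 2), so a threadgroup now covers several output rows, and the per-thread vectorization VEC_C drops to 1.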
diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index 2ccc11f7f49e..f07ee45f3729 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -85,7 +85,7 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         ):
             return None
         # Step 2. Normalize the block, merge spatial and reduction iters
-        is_inner_reduction, c_factor = self._normalize(
+        is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize(
             sch,
             block_info,
             arith.normalize_to_iter_sum(
@@ -97,9 +97,13 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
             return None
         # Step 3. Do the scheduling
         if is_inner_reduction:
-            self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
+            self._sch_inner_reduction(
+                sch, target, block, c_factor, epilogue, loop_order, s_split_index
+            )
         else:
-            self._sch_inner_spatial(sch, target, block, block_info, c_factor, epilogue)
+            self._sch_inner_spatial(
+                sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index
+            )
         return sch
 
     def _normalize(  # pylint: disable=too-many-branches
@@ -112,6 +116,7 @@ def _normalize(  # pylint: disable=too-many-branches
             return None, None
         iter_to_info = {i.var: i for i in block_info.iters}
         s_loops, r_loops, c_loops, c_factor = [], [], [], None
+        s_split_loop, s_split_index = None, None
        for split_expr in access.args:
             var = split_expr.source.source
             info = iter_to_info.pop(var)
@@ -120,6 +125,8 @@ def _normalize(  # pylint: disable=too-many-branches
             if split_expr.lower_factor > 1:
                 if c_loops:
                     return None, None
+                s_split_loop = loop
+                s_split_index = len(s_loops)
                 loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
                 c_loops.append(c_loop)
                 if not is_inner_reduction:
@@ -135,6 +142,22 @@ def _normalize(  # pylint: disable=too-many-branches
                 s_loops.append(info.loop_rv)
             else:
                 return None, None
+
+        loop_order = {}
+        s_block_var_loops = []
+        for i in block_info.iters:
+            if i.loop_rv in s_loops or i.loop_rv == s_split_loop:
+                s_block_var_loops.append(i.loop_rv)
+
+        for i in range(len(s_block_var_loops)):
+            for j in range(len(s_loops)):
+                if s_block_var_loops[i] == s_loops[j]:
+                    loop_order[i] = j
+                    break
+                if s_block_var_loops[i] == s_split_loop:
+                    loop_order[i] = s_split_index
+                    break
+
         assert s_loops
         assert r_loops
         if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]):
@@ -144,7 +167,7 @@ def _normalize(  # pylint: disable=too-many-branches
         sch.reorder(*s_loops, *r_loops, *c_loops)
         sch.fuse(*s_loops)
         sch.fuse(*r_loops)
-        return is_inner_reduction, c_factor
+        return is_inner_reduction, c_factor, loop_order, s_split_index
 
     def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         self,
@@ -153,6 +176,8 @@ def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         block: tir.schedule.BlockRV,
         unroll_spatial_factor: Optional[int],
         epilogue_info: Optional[BlockInfo],
+        loop_order,
+        s_split_index,
     ):
         # pylint: disable=invalid-name
         _, r, _ = sch.get_loops(block)
@@ -174,11 +199,20 @@ def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         # Schedule the write back block
         sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
         _, tx, *s = sch.get_loops(block)
-        s = sch.fuse(*s)
-        sch.reorder(s, tx)
+
         if unroll_spatial_factor:
-            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
-            sch.reorder(s, tx, inner)
+            assert len(s) == len(loop_order)
+            new_order_s = [s[loop_order[i]] for i in range(len(s))]
+            sch.reorder(*new_order_s)
+            new_order_s[s_split_index], c = sch.split(
+                new_order_s[s_split_index], factors=[None, unroll_spatial_factor]
+            )
+            sch.reorder(*new_order_s, c)
+            s = sch.fuse(*new_order_s)
+            sch.reorder(s, tx, c)
+        else:
+            s = sch.fuse(*s)
+            sch.reorder(s, tx)
         sch.bind(tx, "threadIdx.x")
         # Schedule epilogue
         if epilogue_info is not None:
@@ -201,6 +235,8 @@ def _sch_inner_spatial(
         block_info: BlockInfo,
         unroll_spatial_factor: Optional[int],
         epilogue_info: Optional[BlockInfo],
+        loop_order,
+        s_split_index,
     ):
         # pylint: disable=invalid-name
         s, r, _ = sch.get_loops(block)
@@ -226,12 +262,22 @@ def _sch_inner_spatial(
         # Schedule the write back block
         sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
         _, r, *s = sch.get_loops(block)
-        s = sch.fuse(*s)
-        sch.reorder(s, r)
         if unroll_spatial_factor:
-            s, _ = sch.split(s, factors=[None, unroll_spatial_factor])
+            assert len(s) == len(loop_order)
+            new_order_s = [s[loop_order[i]] for i in range(len(s))]
+            sch.reorder(*new_order_s)
+            new_order_s[s_split_index], c = sch.split(
+                new_order_s[s_split_index], factors=[None, unroll_spatial_factor]
+            )
+            sch.reorder(*new_order_s, c)
+            s = sch.fuse(*new_order_s)
+            sch.reorder(s, c, r)
+        else:
+            s = sch.fuse(*s)
+            sch.reorder(s, r)
         sch.bind(s, "threadIdx.x")
         sch.bind(r, "threadIdx.y")
+
         # Schedule epilogue
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
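The new loop_order / s_split_index bookkeeping records, for each spatial block var, where its loop landed in the normalized s_loops order (including the loop that was split off for unroll_spatial_factor), so the write-back block can be put back into block-var order before fusing. A plain-Python sketch of that mapping (hypothetical helper name; _normalize inlines this logic):

def compute_loop_order(block_var_loops, s_loops, s_split_loop, s_split_index):
    """Map position in block-var order -> position in the normalized loop order."""
    loop_order = {}
    for i, loop in enumerate(block_var_loops):
        if loop == s_split_loop:
            loop_order[i] = s_split_index
        elif loop in s_loops:
            loop_order[i] = s_loops.index(loop)
    return loop_order


# Block vars appeared as (b, s), but normalization reordered the loops to [s, b]:
assert compute_loop_order(["b", "s"], ["s", "b"], None, None) == {0: 1, 1: 0}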
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 0383902cd67f..b57d67f83bbd 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -444,10 +444,22 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
 
 /******** PrimFunc-level analysis and transformation ********/
 
+void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array<BlockRV>* leaf_blocks) {
+  Array<BlockRV> blocks = sch->GetChildBlocks(cur_block_rv);
+  if (blocks.empty()) {
+    leaf_blocks->push_back(cur_block_rv);
+  } else {
+    for (const BlockRV& block : blocks) {
+      GetLeafBlocksHelper(sch, block, leaf_blocks);
+    }
+  }
+}
+
 Optional<Array<ObjectRef>> NormalizePrimFunc(Schedule sch) {
   BlockRV root_block = sch->GetBlock("root");
-  Array<BlockRV> blocks = sch->GetChildBlocks(root_block);
-  for (const BlockRV& block : blocks) {
+  Array<BlockRV> leaf_blocks;
+  GetLeafBlocksHelper(sch, root_block, &leaf_blocks);
+  for (const BlockRV& block : leaf_blocks) {
     StmtSRef block_sref = sch->GetSRef(block);
     Array<StmtSRef> loops = GetLoops(block_sref);
     Array<PrimExpr> binds = GetBlockRealize(sch->state(), block_sref)->iter_values;
@@ -465,10 +477,11 @@ Optional<Array<ObjectRef>> NormalizePrimFunc(Schedule sch) {
       }
     }
   }
+
   Array<Array<LoopRV>> block_loops;
   Array<Array<IterVar>> block_iters;
   Array<Bool> block_is_reduction;
-  for (const BlockRV& block : blocks) {
+  for (const BlockRV& block : leaf_blocks) {
     Array<IterVar> iters = sch->Get(block)->iter_vars;
     bool has_spatial_iter = false;
     Array<PrimExpr> index_map_inputs;
@@ -498,7 +511,7 @@ Optional<Array<ObjectRef>> NormalizePrimFunc(Schedule sch) {
                                          sch->GetSRef(root_block));
     block_is_reduction.push_back(Bool(is_reduction));
   }
-  return Array<ObjectRef>{blocks, block_loops, block_iters, block_is_reduction};
+  return Array<ObjectRef>{leaf_blocks, block_loops, block_iters, block_is_reduction};
 }
 
 TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc);
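GetLeafBlocksHelper makes NormalizePrimFunc recurse past blockized outer blocks (such as "gemv_o" in the test below) instead of treating them as the schedulable unit. A Python analogue built on the public schedule API (a sketch under my reading of the C++ above, not an exported utility):

from typing import List
from tvm import tir


def get_leaf_blocks(sch: tir.Schedule, block_rv: tir.schedule.BlockRV) -> List[tir.schedule.BlockRV]:
    """Collect blocks that have no child blocks, depth-first."""
    children = sch.get_child_blocks(block_rv)
    if not children:
        return [block_rv]
    leaves: List[tir.schedule.BlockRV] = []
    for child in children:
        leaves.extend(get_leaf_blocks(sch, child))
    return leaves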
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 83d2c3c06cb1..b4f4250bbf7d 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -209,58 +209,61 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128)
     def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
         T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
         # with T.block("root"):
-        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 22016), "float16", scope="local")
-        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 22016), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local")
         lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
-        for u_fused_ax0_fused_fused_0 in T.thread_binding(22016, thread="blockIdx.x"):
-            for u_fused_ax0_fused_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
-                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(5504, thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.x"):
+                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
                     for u_fused_ax0_fused_fused_2_init in range(1):
-                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4):
+                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(1):
                             with T.block("NT_matmul_rf_init"):
-                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
-                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
+                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                 T.reads()
                                 T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
                                 var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
-                    for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                         for ax0_0, ax1 in T.grid(1, 1):
                             for ax0_1 in T.vectorized(1):
                                 with T.block("lv571_local"):
-                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + ax0_0 + ax0_1)
-                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+                                    v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
                                     T.reads(lv571[v0, v1])
                                     T.writes(lv571_local[v0, v1])
                                     lv571_local[v0, v1] = lv571[v0, v1]
-                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
-                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
+                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 8):
+                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(1):
                                 with T.block("NT_matmul_rf_update"):
-                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
-                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                     vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
-                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
+                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused * 8 + vax1_0_fused_ax1_1_fused_2], lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_2 // 8 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused], lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused * 8 + vax1_0_fused_ax1_1_fused_2) // 32])
                                     T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
-                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32])
+                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused * 8 + vax1_0_fused_ax1_1_fused_2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 16 + vax1_0_fused_ax1_1_fused_2 // 8 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused * 8 + vax1_0_fused_ax1_1_fused_2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 128 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused * 8 + vax1_0_fused_ax1_1_fused_2) // 32])
-                for ax2_fused_0 in T.thread_binding(1, thread="threadIdx.x"):
-                    for ax0 in T.thread_binding(64, thread="threadIdx.y"):
+                for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
                         for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                             for ax2_fused_1_1 in T.vectorized(1):
                                 with T.block("NT_matmul_rf_init"):
-                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, v0 = T.axis.remap("SS", [ax0, u_fused_ax0_fused_fused_0])
+                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(16, ax0)
+                                    v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
                                     T.reads()
                                     T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                     var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
-                                for ax1 in range(4):
+                                for ax1 in range(1):
                                     with T.block("NT_matmul_rf_update"):
-                                        vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, v0 = T.axis.remap("SRS", [ax0, ax1, u_fused_ax0_fused_fused_0])
-                                        T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+                                        vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                        v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                        T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
                                         T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
-                                        var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+                                        var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
                 for ax1_fused_1 in range(1):
-                    for ax1_fused_0 in T.thread_binding(1, thread="threadIdx.x"):
-                        for ax0 in T.thread_binding(64, thread="threadIdx.y"):
+                    for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(16, thread="threadIdx.y"):
                             with T.block("NT_matmul"):
-                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, v0 = T.axis.remap("RS", [ax0, u_fused_ax0_fused_fused_0])
+                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(16, ax0)
+                                v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
                                 T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
                                 T.writes(var_NT_matmul_intermediate[0, 0, v0])
                                 with T.init():
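Sanity arithmetic for the retuned expectation above (my figures, read off the script): TS=4/TR=16 from gemv.py become the threadIdx.x=4 and threadIdx.y=16 bindings, and the loop extents still tile the whole problem:

assert 5504 * 4 == 22016    # blockIdx.x * threadIdx.x covers every output row
assert 32 * 16 == 512       # serial k-outer * threadIdx.y covers lv571's uint32 columns
assert 32 * 16 * 8 == 4096  # eight 4-bit weights per uint32 give the full reduction extent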
@@ -893,5 +896,106 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1),
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_blockized_gemv():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")):
+        # with T.block("root"):
+        for expert_id in T.thread_binding(2, thread="blockIdx.y"):
+            with T.block("gemv_o"):
+                v_expert_id_o = T.axis.spatial(2, expert_id)
+                vi_o = T.axis.spatial(1, 0)
+                vj_o = T.axis.reduce(1, 0)
+                T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o])
+                T.writes(o[v_expert_id_o, 0:16384])
+                for i, j in T.grid(16384, 4096):
+                    with T.block("gemv"):
+                        vi_i, vj_i = T.axis.remap("SR", [i, j])
+                        T.reads(x[0, vj_i], w[indptr[v_expert_id_o], vi_i, vj_i], indptr[v_expert_id_o])
+                        T.writes(o[v_expert_id_o, vi_i])
+                        with T.init():
+                            o[v_expert_id_o, vi_i] = T.float16(0)
+                        o[v_expert_id_o, vi_i] = o[v_expert_id_o, vi_i] + x[0, vj_i] * w[indptr[v_expert_id_o], vi_i, vj_i]
+
+    @T.prim_func(private=True)
+    def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")):
+        T.func_attr({"tir.is_scheduled": 1})
+        # with T.block("root"):
+        for expert_id in T.thread_binding(2, thread="blockIdx.y"):
+            with T.block("gemv_o"):
+                v_expert_id_o = T.axis.spatial(2, expert_id)
+                vi_o = T.axis.spatial(1, 0)
+                vj_o = T.axis.reduce(1, 0)
+                T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o])
+                T.writes(o[v_expert_id_o, 0:16384])
+                o_rf_local = T.alloc_buffer((16, 2, 16384), "float16", scope="local")
+                o_rf_local_1 = T.alloc_buffer((16, 2, 16384), "float16", scope="local")
+                w_local = T.alloc_buffer((1, 16384, 4096), "float16", scope="local")
+                for u_fused_ax0_fused_fused_0 in T.thread_binding(4096, thread="blockIdx.x"):
+                    for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.x"):
+                        for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                            for u_fused_ax0_fused_fused_2_init in range(1):
+                                for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(1):
+                                    with T.block("gemv_rf_init"):
+                                        vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(16, ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init)
+                                        v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
+                                        T.reads()
+                                        T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0])
+                                        o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] = T.float16(0)
+                            for ax1_fused_u_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax0, ax1_0, ax2 in T.grid(1, 1, 8):
+                                    for ax1_1 in T.vectorized(1):
+                                        with T.block("w_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1)
+                                            v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax2)
+                                            T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
+                                            T.writes(w_local[v0, v1, v2])
+                                            w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2]
+                                for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(1, 8):
+                                    for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1):
+                                        with T.block("gemv_rf_update"):
+                                            vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(16, ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1)
+                                            v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
+                                            vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2])
+                                            T.reads(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0], x[0, vax1_fused_u_fused_0 * 128 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused * 8 + vax1_fused_u_fused_2], w_local[0, v0, vax1_fused_u_fused_0 * 128 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused * 8 + vax1_fused_u_fused_2], indptr[v_expert_id_o])
+                                            T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0])
+                                            o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] = o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] + x[0, vax1_fused_u_fused_0 * 128 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused * 8 + vax1_fused_u_fused_2] * w_local[0, v0, vax1_fused_u_fused_0 * 128 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused * 8 + vax1_fused_u_fused_2]
+                for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                        for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax2_fused_1_1 in T.vectorized(1):
+                                with T.block("gemv_rf_init"):
+                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(16, ax0)
+                                    v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                    T.reads()
+                                    T.writes(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0])
+                                    o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0] = T.float16(0)
+                                for ax1 in range(1):
+                                    with T.block("gemv_rf_update"):
+                                        vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                        v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                        T.reads(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0], o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, v_expert_id_o, v0])
+                                        T.writes(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0])
+                                        o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0] = o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0] + o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, v_expert_id_o, v0]
+                for ax1_fused_1 in range(1):
+                    for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(16, thread="threadIdx.y"):
+                            with T.block("gemv"):
+                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(16, ax0)
+                                v0 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
+                                T.reads(o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0])
+                                T.writes(o[v_expert_id_o, v0])
+                                with T.init():
+                                    o[v_expert_id_o, v0] = T.float16(0)
+                                o[v_expert_id_o, v0] = o[v_expert_id_o, v0] + o_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, v_expert_id_o, v0]
+    # fmt: on
+
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
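The blockized test exercises both halves of the patch at once: the leaf-block recursion in transform.cc reaches the inner "gemv" block behind "gemv_o", and the relaxed base check in gemv.py accepts the indptr[...]-based weight offset. A quick check along those lines (my sketch; it reuses `before` from the test above and selects the `w` read by buffer name):

from tvm import tir
from tvm.dlight.base import collect_block_iter_vars_used_in_access_region

sch = tir.Schedule(before)
inner = sch.get(sch.get_block("gemv"))
w_read = next(r for r in inner.reads if r.buffer.name == "w")
# w is read at [indptr[v_expert_id_o], vi_i, vj_i]; v_expert_id_o belongs to the
# outer block, so only vi_i and vj_i survive the intersection:
print(collect_block_iter_vars_used_in_access_region(inner, w_read.region))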
diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py
index ac34d1ee91a4..75d2eeeb0716 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -223,7 +223,7 @@ def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T
             for ax1_fused_1 in range(8):
                 with T.block("matmul"):
                     vax1_fused_1 = T.axis.reduce(1024, ax0)
-                    v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax1_fused_0 * 8 + ax1_fused_1)
+                    v0 = T.axis.spatial(4096, ax0_0_fused * 8 + ax1_fused_1)
                     T.reads(C_rf_local[vax1_fused_1, 0, 0, v0])
                     T.writes(C[0, 0, v0])
                     with T.init():
@@ -922,5 +922,89 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64
     assert_structural_equal(mod, Expected)
 
 
+def test_repeat_transpose_gemv():
+    # fmt: off
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            kv_seq_len = T.int64()
+            lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16")
+            astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16")
+            # with T.block("root"):
+            var_T_repeat_intermediate = T.alloc_buffer((T.int64(1), kv_seq_len, T.int64(32), T.int64(128)), "float16")
+            var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), kv_seq_len, T.int64(128)), "float16")
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), kv_seq_len, T.int64(32), T.int64(128)):
+                with T.block("T_repeat"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3])
+                    T.writes(var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+                    var_T_repeat_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv716[v_ax0, v_ax1, v_ax2 // T.int64(4), v_ax3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), kv_seq_len, T.int64(128)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+                    var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_repeat_intermediate[v_ax0, v_ax2, v_ax1, v_ax3]
+            for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                    T.reads(astype66[v_i0, v_i1, v_i2, v_k], var_T_transpose_intermediate[v_i0, v_i1, v_k, v_i3])
+                    T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
+                    with T.init():
+                        var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                    var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + astype66[v_i0, v_i1, v_i2, v_k] * var_T_transpose_intermediate[v_i0, v_i1, v_k, v_i3]
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            kv_seq_len = T.int64()
+            lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16")
+            astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16")
+            # with T.block("root"):
+            var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16", scope="local")
+            for ax0_0_ax1_fused_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"):
+                for ax0_0_ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                    for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        for ax0_1_init in range(T.int64(4)):
+                            with T.block("matmul_rf_init"):
+                                vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1)
+                                v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1_init)
+                                v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128))
+                                T.reads()
+                                T.writes(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                                var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_0, ax0_1 in T.grid((kv_seq_len + T.int64(15)) // T.int64(16), T.int64(4)):
+                            with T.block("matmul_rf_update"):
+                                vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1)
+                                v0 = T.axis.spatial(T.int64(32), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) // T.int64(128) * T.int64(4) + ax0_1)
+                                v1 = T.axis.spatial(T.int64(128), (ax0_0_ax1_fused_0 * T.int64(16) + ax0_0_ax1_fused_1) % T.int64(128))
+                                vax2_fused_0 = T.axis.reduce((kv_seq_len + T.int64(15)) // T.int64(16), ax2_fused_0)
+                                T.where(ax2_fused_0 * T.int64(16) + ax2_fused_1 < kv_seq_len)
+                                T.reads(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], astype66[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1], lv716[T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1, v0 // T.int64(4), v1])
+                                T.writes(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                                var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + astype66[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * lv716[T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1, v0 // T.int64(4), v1]
+                for ax1_0_ax2_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                    for ax1_1 in range(T.int64(4)):
+                        for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            with T.block("matmul"):
+                                vax2_fused_1 = T.axis.reduce(T.int64(16), ax0)
+                                v0 = T.axis.spatial(T.int64(32), ax0_0_ax1_fused_0 // T.int64(8) * T.int64(4) + ax1_1)
+                                v1 = T.axis.spatial(T.int64(128), ax0_0_ax1_fused_0 % T.int64(8) * T.int64(16) + ax1_0_ax2_fused)
+                                T.reads(var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
+                                T.writes(var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1])
+                                with T.init():
+                                    var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
+                                var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] = var_matmul_intermediate[T.int64(0), v0, T.int64(0), v1] + var_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]
+    # fmt: on
+
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()