From 772e1f0ed21bdc0cf8b4a31c865f512ad49d3d19 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Thu, 15 Feb 2024 17:31:17 -0500 Subject: [PATCH 1/8] low batch --- python/tvm/dlight/gpu/__init__.py | 1 + python/tvm/dlight/gpu/low_batch_gemv.py | 563 ++++++++++++++++++++++++ src/driver/driver_api.cc | 9 +- src/tir/transforms/hoist_expression.cc | 9 +- 4 files changed, 579 insertions(+), 3 deletions(-) create mode 100644 python/tvm/dlight/gpu/low_batch_gemv.py diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py index 7db383a161cd..077fdcaeb023 100644 --- a/python/tvm/dlight/gpu/__init__.py +++ b/python/tvm/dlight/gpu/__init__.py @@ -19,6 +19,7 @@ For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ from .gemv import GEMV +from .low_batch_gemv import LowBatchGEMV from .fallback import Fallback from .matmul import Matmul from .reduction import Reduction diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py new file mode 100644 index 000000000000..3a01640faed8 --- /dev/null +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -0,0 +1,563 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A rule for GEMV and DecodeGEMV.""" +import re +from functools import reduce +from typing import List, Optional, Union, Set + +from tvm import DataType, arith, ir, tir +from tvm.target import Target + +from ..base import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + is_broadcast_epilogue, + normalize_prim_func, + try_inline_contiguous_spatial, +) +from .base import GPUScheduleRule +from .matmul import auto_inline_consumer_chain, auto_inline_producers + + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): + loop: tir.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + + +def get_bytes(dtype: Union[DataType, str]) -> int: + num = re.findall(r"\d+", dtype) + if len(num) != 1: + raise ValueError(f"Cannot get bytes from {dtype}") + return int(num[0]) // 8 + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. + """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(_get_reduction_expr(block_stmt) is not None) + conditions.append( + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0 + ) + if not all(conditions): + return None + const_iter_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if isinstance(iter_var.dom.extent, tir.IntImm)) + if len(const_iter_vars) == len(block_stmt.iter_vars): + return None + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars) < len(const_iter_vars) + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars) > 0 + ] + return ret if 0 < len(ret) < len(block_stmt.reads) else None + +def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir.PrimExpr: + """Detect the dominant read indices in the block.""" + dominant_read = None + num_read_iters = -1 + for buffer_region in block.reads: + tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) & const_iter_vars + if num_read_iters < len(tir_vars): + num_read_iters = len(tir_vars) + dominant_read = buffer_region + assert dominant_read is not None + (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) + return result + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + const_iter_vars = set( + iter_var.var + for iter_var in block_stmt.iter_vars + if isinstance(iter_var.dom.extent, tir.IntImm) + ) + dynamic_iter_vars = set( + iter_var.var + for iter_var in block_stmt.iter_vars + if iter_var.var not in const_iter_vars + ) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt, const_iter_vars), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend( + [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ] + ) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars + ): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + dynamic_loops = [iter_to_info[var].loop_rv for var in dynamic_iter_vars] + assert len(dynamic_loops) == 1 + if not batch_loops: + batch_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction + + +class LowBatchGEMV(GPUScheduleRule): + """A rule for GEMV and DecodeGEMV.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()] + if len(reduction_block_infos) != 1: + return None + reduction_block_info = reduction_block_infos[0] + vector_input_buffers = is_gemv(sch, reduction_block_info) + if vector_input_buffers is None: + return None + batch_pad = 4 if len(block_infos) == 1 else 1 + pad_value = [iter.dom if isinstance(iter.dom, int) else batch_pad for iter in reduction_block_info.iters ] + sch.pad_einsum(reduction_block_info.block_rv, pad_value) + block_infos = normalize_prim_func(sch) + block_infos = [block_info for block_info in block_infos if "pad" not in block_info.name] + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + block = block_info.block_rv + vector_input_buffers = is_gemv(sch, block_info) + if vector_input_buffers is None: + return None + + # Step 1. Normalize the block, merge spatial and reduction iters + is_inner_reduction = normalize(sch, block_info) + # Step 2. Do the scheduling + if is_inner_reduction is None: + return None + elif is_inner_reduction: + self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue, batch_pad) + return sch + else: + raise NotImplementedError("Outer reduction is not supported yet") + + def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + batch_pad: int, + ): + """Schedule the inner reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + TILE_S, + TILE_R, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + ): + # rfactor: reduce to tx * vec_c + auto_inline_producers(sch, gemv) + + _, b, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(b, s) + r = sch.fuse(r, c) + bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) + r, tr, tile_r_vec_n, vec_c = sch.split( + r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True + ) + sch.reorder(r, tile_r_vec_n, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + _, bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + # bind, vectorize compute + batch_loop, bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) + sch.bind(bx, "blockIdx.x") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_c) + by, batch = sch.split(batch_loop, factors=[None, batch_pad]) + sch.bind(by, "blockIdx.y") + sch.reorder(bx, ts, tr, r, batch) + + shared_mem_usage = 0 + for buf in vector_input_buffers: + buf_size = reduce( + lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) + ) * get_bytes(buf.dtype) + shared_mem_usage += buf_size + LOAD_V_SHARED = ( + LOAD_V_SHARED + and isinstance(shared_mem_usage, tir.IntImm) + and shared_mem_usage.value <= target.max_shared_memory_per_block + ) + + # vectorize load A + # (TODO) this is now actually problematic since the number of loops is dependent on the + # number of dimensions of A_q + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, r, preserve_unit_loops=True) + + s_local, r_local = sch.get_loops(block=Aq_local)[-2:] + s_local, vec_load = sch.split( + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True + ) + sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 + sch.vectorize(vec_load) + + + # load vector into shared memory, shape should be the whole vector + if LOAD_V_SHARED: + assert len(vector_input_buffers) == 1 + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") + sch.compute_at(V_shared, tr, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] + loop: tir.For = sch.get(l) + if isinstance(loop.extent, tir.IntImm): + # avoid introducing predicates when vector length is too large + vec_length = max( + min( + get_max_factor( + (int)(loop.extent), + [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], + ) + // TS + // TR, + LOAD_V_VEC, + ), + 1, + ) + else: + vec_length = LOAD_V_VEC + if TAG_R == "threadIdx.x": + _, ty, tx, vec = sch.split( + l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True + ) + else: + _, ty, tx, vec = sch.split( + l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + tile_s, vec_s = sch.split( + tile_s, + factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], + preserve_unit_iters=True, + ) + sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_s) + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + + tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.reorder(tile_s, batch_loop, ts, tr) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[4]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + unroll_factor = UNROLL + + sch.annotate( + block_or_loop=sch.get_loops(rf)[4], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf)[4], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[4], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[4], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + if LOAD_V_SHARED: + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_unroll_explicit", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], ann_key="pragma_vectorize", ann_val=1 + ) + + epilogue = sch.get_consumers(gemv) + # Schedule epilogue + if epilogue: + epilogue = epilogue[0] + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _,_, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.bind(ts, TAG_S) + sch.set_scope(block, 0, "local") + + + return sch + + # Specify the `len_tx` and `len_ty` according to the loop extent + _, batch, s, r, c = sch.get_loops(block=block) + len_batch, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + len_S = len_batch * len_s + len_R = len_r * len_c + + TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" + if target.kind.name == "cuda": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 64 + else: + TS, TR = 16, 32 + elif target.kind.name == "metal": + # Note that the following tile size is tuned on M2 Ultra for 7B + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 16 + else: + TS, TR = 2, 64 + elif target.kind.name == "rocm": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 1, 128 + else: + TS, TR = 8, 64 + elif target.kind.name == "opencl" and "android" in str(target.host): + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 8 + TS, TR = 2, 32 + elif target.kind.name == "vulkan": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 4 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 32 + else: + TS, TR = 16, 32 + elif target.kind.name == "opencl" and "mali" in str(target.attrs): + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + else: + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + + if not isinstance(len_S, int): + TS, TR = 1, 64 + + while TS * TR > target.max_num_threads: + if TS > 1: + TS //= 2 + else: + TR //= 2 + + TILE_S, TILE_R = ( + 1, + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ) + VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + TILE_S=TILE_S, + TILE_R=TILE_R, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 17cd5c49a1bf..8547184edcae 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -240,6 +240,10 @@ Array CreatePassList(bool disable_loop_partition) { if (use_async_copy) { pass_list.push_back(tir::transform::LowerAsyncDMA()); } + // HoistIfThenElse must be applied before UnrollLoop + // because HoistIfThenElse could utilize for loop structure + // which might be unrolled in UnrollLoop + pass_list.push_back(tir::transform::HoistIfThenElse()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes @@ -250,7 +254,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); - pass_list.push_back(tir::transform::HoistIfThenElse()); // Add user-defined phase-3 passes pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); @@ -585,7 +588,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); - mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); @@ -603,6 +605,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + // MergeSharedMemoryAllocations must be applied after SplitHostDevice + // because the merged allocation site is at the beginning of each device function + mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) .value_or(relay::Executor::Create("graph", {})) diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index 494fd7184fc3..f0fc90ee3244 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -558,7 +558,14 @@ Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); - + auto flag = f->GetAttr("tir.HoistIfThenElseExprWithBlock"); + if (flag && flag.value().IntValue() == 1) { + HoistExpressionConfig config(static_cast(HoistedConditionals::kUsingBlockVar) | + static_cast(HoistedConditionals::kIfElseExpr), + static_cast(HoistedLetBindings::kNone)); + n->body = ExpressionHoister::Hoist(std::move(n->body), config); + return f; + } if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } From 34ee0a3dd927172688477915afe91341b1798834 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Thu, 15 Feb 2024 17:34:12 -0500 Subject: [PATCH 2/8] fix --- python/tvm/dlight/gpu/low_batch_gemv.py | 60 ++++++++++++++++--------- src/driver/driver_api.cc | 2 +- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 3a01640faed8..8505944e8aaf 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""A rule for GEMV and DecodeGEMV.""" +"""A rule for low-batch GEMM / decode-GEMM using GEMV schedule.""" import re from functools import reduce from typing import List, Optional, Union, Set @@ -31,7 +31,7 @@ try_inline_contiguous_spatial, ) from .base import GPUScheduleRule -from .matmul import auto_inline_consumer_chain, auto_inline_producers +from .matmul import auto_inline_producers def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -63,7 +63,7 @@ def get_bytes(dtype: Union[DataType, str]) -> int: def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: - """Check if the block is a GEMV. + """Check if the block is a low batch GEMM. Parameters ---------- @@ -78,7 +78,7 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe Returns ------- ret : Optional[List[tir.Buffer]] - The vector buffers used in the GEMV if it is a GEMV, otherwise None. + The vector-like buffers used in the low batch GEMM if it is a low batch GEMM, otherwise None. """ block = block_info.block_rv block_stmt = sch.get(block) @@ -93,23 +93,37 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe ) if not all(conditions): return None - const_iter_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if isinstance(iter_var.dom.extent, tir.IntImm)) + const_iter_vars = set( + iter_var.var + for iter_var in block_stmt.iter_vars + if isinstance(iter_var.dom.extent, tir.IntImm) + ) if len(const_iter_vars) == len(block_stmt.iter_vars): return None ret = [ read.buffer for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars) < len(const_iter_vars) - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars) > 0 + if len( + collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars + ) + < len(const_iter_vars) + and len( + collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars + ) + > 0 ] return ret if 0 < len(ret) < len(block_stmt.reads) else None + def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir.PrimExpr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 for buffer_region in block.reads: - tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) & const_iter_vars + tir_vars = ( + collect_block_iter_vars_used_in_access_region(block, buffer_region.region) + & const_iter_vars + ) if num_read_iters < len(tir_vars): num_read_iters = len(tir_vars) dominant_read = buffer_region @@ -117,6 +131,7 @@ def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) return result + def normalize( sch: tir.Schedule, block_info: BlockInfo, @@ -129,9 +144,7 @@ def normalize( if isinstance(iter_var.dom.extent, tir.IntImm) ) dynamic_iter_vars = set( - iter_var.var - for iter_var in block_stmt.iter_vars - if iter_var.var not in const_iter_vars + iter_var.var for iter_var in block_stmt.iter_vars if iter_var.var not in const_iter_vars ) access = arith.normalize_to_iter_sum( detect_dominant_read(block_stmt, const_iter_vars), @@ -192,7 +205,7 @@ def normalize( class LowBatchGEMV(GPUScheduleRule): - """A rule for GEMV and DecodeGEMV.""" + """A rule for low batch GEMM / decode-GEMM.""" def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, @@ -204,8 +217,10 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) - - reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()] + + reduction_block_infos = [ + block_info for block_info in block_infos if block_info.is_reduction() + ] if len(reduction_block_infos) != 1: return None reduction_block_info = reduction_block_infos[0] @@ -213,7 +228,10 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if vector_input_buffers is None: return None batch_pad = 4 if len(block_infos) == 1 else 1 - pad_value = [iter.dom if isinstance(iter.dom, int) else batch_pad for iter in reduction_block_info.iters ] + pad_value = [ + iter.dom if isinstance(iter.dom, int) else batch_pad + for iter in reduction_block_info.iters + ] sch.pad_einsum(reduction_block_info.block_rv, pad_value) block_infos = normalize_prim_func(sch) block_infos = [block_info for block_info in block_infos if "pad" not in block_info.name] @@ -328,7 +346,7 @@ def apply( # number of dimensions of A_q Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") sch.compute_at(Aq_local, r, preserve_unit_loops=True) - + s_local, r_local = sch.get_loops(block=Aq_local)[-2:] s_local, vec_load = sch.split( s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True @@ -336,7 +354,6 @@ def apply( sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 sch.vectorize(vec_load) - # load vector into shared memory, shape should be the whole vector if LOAD_V_SHARED: assert len(vector_input_buffers) == 1 @@ -389,7 +406,7 @@ def apply( # reduce tile_s * tr to tile_s sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) - + tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:] ts_tile_s = sch.fuse(*ts_tile_s) ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) @@ -432,7 +449,7 @@ def apply( sch.annotate( block_or_loop=sch.get_loops(V_shared)[-4], ann_key="pragma_vectorize", ann_val=1 ) - + epilogue = sch.get_consumers(gemv) # Schedule epilogue if epilogue: @@ -440,7 +457,7 @@ def apply( if is_broadcast_epilogue(sch, block, epilogue): sch.reverse_compute_at(epilogue, bx) sch.set_scope(block, 0, "shared") - _, _,_, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) sch.bind(tx, "threadIdx.x") else: @@ -451,7 +468,6 @@ def apply( sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") - return sch # Specify the `len_tx` and `len_ty` according to the loop extent @@ -545,7 +561,7 @@ def apply( ) VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) VEC_LOAD = 1 - + return apply( sch, gemv=block, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 8547184edcae..692ce35f00de 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -605,7 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - // MergeSharedMemoryAllocations must be applied after SplitHostDevice + // MergeSharedMemoryAllocations must be applied after SplitHostDevice // because the merged allocation site is at the beginning of each device function mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); From 12192da03c52b9b4679c99d4f1c0918f9ee70838 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Thu, 15 Feb 2024 17:53:39 -0500 Subject: [PATCH 3/8] fix lint --- python/tvm/dlight/gpu/low_batch_gemv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 8505944e8aaf..60ff9c6b1639 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -78,7 +78,8 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe Returns ------- ret : Optional[List[tir.Buffer]] - The vector-like buffers used in the low batch GEMM if it is a low batch GEMM, otherwise None. + The vector-like buffers used in the low batch GEMM if it is a low batch GEMM, + otherwise None. """ block = block_info.block_rv block_stmt = sch.get(block) From cd3761b85e1500bcf7b428a92b94ebab4dcc88f6 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Mon, 19 Feb 2024 17:41:40 -0500 Subject: [PATCH 4/8] do dequantize only once --- python/tvm/dlight/gpu/low_batch_gemv.py | 46 ++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 60ff9c6b1639..973a11037046 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -207,6 +207,10 @@ def normalize( class LowBatchGEMV(GPUScheduleRule): """A rule for low batch GEMM / decode-GEMM.""" + + def __init__(self, bucket = 1): + self.bucket = bucket + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, @@ -228,14 +232,21 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- vector_input_buffers = is_gemv(sch, reduction_block_info) if vector_input_buffers is None: return None - batch_pad = 4 if len(block_infos) == 1 else 1 + batch_pad = self.bucket pad_value = [ iter.dom if isinstance(iter.dom, int) else batch_pad for iter in reduction_block_info.iters ] sch.pad_einsum(reduction_block_info.block_rv, pad_value) block_infos = normalize_prim_func(sch) - block_infos = [block_info for block_info in block_infos if "pad" not in block_info.name] + dequantize_block = None + pad_input_block = None + for block_info in block_infos: + if "dequantize" in block_info.name: + dequantize_block = block_info.block_rv + elif "pad" in block_info.name and len(sch.get_producers(block_info.block_rv)) == 0: + pad_input_block = block_info.block_rv + block_infos = [block_info for block_info in block_infos if "pad" not in block_info.name and "dequantize" not in block_info.name] block_infos = try_inline_contiguous_spatial(sch, block_infos) if len(block_infos) == 1: epilogue = None @@ -262,7 +273,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if is_inner_reduction is None: return None elif is_inner_reduction: - self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue, batch_pad) + self.sch_inner_reduction(sch, target, block, dequantize_block, pad_input_block, vector_input_buffers, epilogue, batch_pad) return sch else: raise NotImplementedError("Outer reduction is not supported yet") @@ -272,6 +283,8 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un sch: tir.Schedule, target: Target, block: tir.schedule.BlockRV, + dequantize_block: Optional[tir.schedule.BlockRV], + pad_input_block: Optional[tir.schedule.BlockRV], vector_input_buffers: List[tir.Buffer], epilogue_info: Optional[BlockInfo], batch_pad: int, @@ -301,7 +314,6 @@ def apply( UNROLL, ): # rfactor: reduce to tx * vec_c - auto_inline_producers(sch, gemv) _, b, s, r, c = sch.get_loops(block=gemv) s = sch.fuse(b, s) @@ -345,15 +357,16 @@ def apply( # vectorize load A # (TODO) this is now actually problematic since the number of loops is dependent on the # number of dimensions of A_q - Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") - sch.compute_at(Aq_local, r, preserve_unit_loops=True) - - s_local, r_local = sch.get_loops(block=Aq_local)[-2:] - s_local, vec_load = sch.split( - s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True - ) - sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 - sch.vectorize(vec_load) + if dequantize_block is not None: + sch.compute_at(dequantize_block, r, preserve_unit_loops=True) + sch.set_scope(dequantize_block, 0, "local") + + s_local, r_local = sch.get_loops(block=dequantize_block)[-2:] + s_local, vec_load = sch.split( + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True + ) + sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 + sch.vectorize(vec_load) # load vector into shared memory, shape should be the whole vector if LOAD_V_SHARED: @@ -389,6 +402,8 @@ def apply( sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") sch.vectorize(vec) + if pad_input_block is not None: + sch.compute_inline(pad_input_block) # reduce tile_s * tr * vec to tile_s * tr sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) @@ -502,7 +517,7 @@ def apply( UNROLL = 256 if isinstance(len_S, int): if len_S > len_R: - TS, TR = 4, 16 + TS, TR = 2, 32 else: TS, TR = 2, 64 elif target.kind.name == "rocm": @@ -555,14 +570,13 @@ def apply( TR //= 2 TILE_S, TILE_R = ( - 1, + 2, len_c if len_c > 1 else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), ) VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) VEC_LOAD = 1 - return apply( sch, gemv=block, From 9f612ef84be1a9bc3fbc242929a4de03600b57b2 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Mon, 19 Feb 2024 19:54:24 -0500 Subject: [PATCH 5/8] change default --- python/tvm/dlight/gpu/low_batch_gemv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 973a11037046..a9bf427c31fe 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -208,7 +208,7 @@ def normalize( class LowBatchGEMV(GPUScheduleRule): """A rule for low batch GEMM / decode-GEMM.""" - def __init__(self, bucket = 1): + def __init__(self, bucket = 4): self.bucket = bucket From b14dfa12151cf6eaa3ff93c58baa49df5d050de1 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Mon, 19 Feb 2024 20:17:30 -0500 Subject: [PATCH 6/8] add test --- python/tvm/dlight/gpu/low_batch_gemv.py | 24 +- .../python/dlight/test_gpu_low_batch_gemv.py | 255 ++++++++++++++++++ 2 files changed, 273 insertions(+), 6 deletions(-) create mode 100644 tests/python/dlight/test_gpu_low_batch_gemv.py diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index a9bf427c31fe..c55a3a86d0db 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -207,10 +207,9 @@ def normalize( class LowBatchGEMV(GPUScheduleRule): """A rule for low batch GEMM / decode-GEMM.""" - - def __init__(self, bucket = 4): + + def __init__(self, bucket=4): self.bucket = bucket - def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, @@ -246,7 +245,11 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- dequantize_block = block_info.block_rv elif "pad" in block_info.name and len(sch.get_producers(block_info.block_rv)) == 0: pad_input_block = block_info.block_rv - block_infos = [block_info for block_info in block_infos if "pad" not in block_info.name and "dequantize" not in block_info.name] + block_infos = [ + block_info + for block_info in block_infos + if "pad" not in block_info.name and "dequantize" not in block_info.name + ] block_infos = try_inline_contiguous_spatial(sch, block_infos) if len(block_infos) == 1: epilogue = None @@ -273,7 +276,16 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- if is_inner_reduction is None: return None elif is_inner_reduction: - self.sch_inner_reduction(sch, target, block, dequantize_block, pad_input_block, vector_input_buffers, epilogue, batch_pad) + self.sch_inner_reduction( + sch, + target, + block, + dequantize_block, + pad_input_block, + vector_input_buffers, + epilogue, + batch_pad, + ) return sch else: raise NotImplementedError("Outer reduction is not supported yet") @@ -360,7 +372,7 @@ def apply( if dequantize_block is not None: sch.compute_at(dequantize_block, r, preserve_unit_loops=True) sch.set_scope(dequantize_block, 0, "local") - + s_local, r_local = sch.get_loops(block=dequantize_block)[-2:] s_local, vec_load = sch.split( s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py new file mode 100644 index 000000000000..aea4b1bda745 --- /dev/null +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import tir as T +from tvm.target import Target + + +def test_batch_decode_gemv(): + # fmt: off + + @T.prim_func(private=True) + def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.HoistIfThenElseExprWithBlock": 1}) + batch_size = T.int64() + lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") + NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") + # with T.block("root"): + compute = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16") + dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16") + for i0, i1 in T.grid(T.int64(4096), T.int64(28672)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(lv429[v_i0, v_i1 // T.int64(8)]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) + for i0, i1 in T.grid(T.int64(4096), T.int64(28672)): + with T.block("dequantize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[v_i0, v_i1], lv430[v_i0, v_i1 // T.int64(32)]) + T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) + dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv430[v_i0, v_i1 // T.int64(32)] + for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(4096), T.int64(28672)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv807[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_i2, v_k]) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv807[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_i2, v_k] + + @T.prim_func(private=True) + def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle): + T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") + NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") + # with T.block("root"): + dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16", scope="local") + NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"): + for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"): + for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1]) + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0) + for ax2_fused_u_fused_0 in T.serial(T.int64(56), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)): + for ax0_1 in T.vectorized(T.int64(1)): + with T.block("dequantize"): + v0 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + ax0_0_1 + ax0_1) + v1 = T.axis.spatial(T.int64(28672), ax2_fused_u_fused_0 * T.int64(512) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1) + T.reads(lv429[v0, v1 // T.int64(8)], lv430[v0, v1 // T.int64(32)]) + T.writes(dequantize_intermediate_intermediate_local[v0, v1]) + dequantize_intermediate_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv430[v0, v1 // T.int64(32)] + for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) + vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2]) + T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2]) + T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1]) + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2] + for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2 in range(T.int64(4)): + for ax3_fused_1_1 in T.vectorized(T.int64(2)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1) + T.reads() + T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0) + for ax1 in range(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1) + T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]) + T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1] + for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)): + for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + with T.block("NT_matmul"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1) + T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]) + with T.init(): + NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0) + NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + for ax0 in range(T.int64(4)): + for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax1_fused_1 in range(T.int64(2)): + with T.block("NT_matmul_intermediate_pad"): + v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) + T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]) + T.writes(NT_matmul_intermediate[v0, T.int64(0), v1]) + NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("metal"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_batch_gemv(): + N = 4096 + K = 4096 + # fmt: off + @T.prim_func(private=True) + def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), var_NT_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.HoistIfThenElseExprWithBlock": 1}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(K)), "float16") + NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(N)), "float16") + # with T.block("root"): + for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(N), T.int64(K)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) + T.writes(NT_matmul[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] + + @T.prim_func(private=True) + def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle): + T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + batch_size = T.int64() + A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), "float16") + NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(4096)), "float16") + # with T.block("root"): + NT_matmul_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local") + for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"): + for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"): + for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init) + T.reads() + T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1]) + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0) + for ax2_fused_u_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2) + vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2]) + T.reads(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2]) + T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1]) + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2] + for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2 in range(T.int64(4)): + for ax3_fused_1_1 in T.vectorized(T.int64(2)): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1) + T.reads() + T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0) + for ax1 in range(T.int64(1)): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1) + T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]) + T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1] + for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)): + for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"): + with T.block("NT_matmul"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0) + v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1) + T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]) + T.writes(NT_matmul_pad_local[v0, T.int64(0), v1]) + with T.init(): + NT_matmul_pad_local[v0, T.int64(0), v1] = T.float16(0) + NT_matmul_pad_local[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1] + NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + for ax0 in range(T.int64(4)): + for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"): + for ax1_fused_1 in range(T.int64(2)): + with T.block("NT_matmul_pad"): + v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) + v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) + T.reads(NT_matmul_pad_local[v0, T.int64(0), v1]) + T.writes(NT_matmul[v0, T.int64(0), v1]) + NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("metal"): + mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +if __name__ == "__main__": + tvm.testing.main() From b06d9fe5be54020d6744869e7be70423e1ac0ac1 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Mon, 19 Feb 2024 20:59:35 -0500 Subject: [PATCH 7/8] fix lint --- tests/python/dlight/test_gpu_low_batch_gemv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index aea4b1bda745..5827b7b81077 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -55,7 +55,7 @@ def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.B with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv807[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_i2, v_k] - + @T.prim_func(private=True) def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle): T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) @@ -140,7 +140,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]) T.writes(NT_matmul_intermediate[v0, T.int64(0), v1]) - NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] # fmt: on mod = tvm.IRModule({"main": before}) with Target("metal"): @@ -167,7 +167,7 @@ def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), va with T.init(): NT_matmul[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] - + @T.prim_func(private=True) def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle): T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) From 75b49fedd64e0260e273563670a963c818964d8c Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Tue, 20 Feb 2024 14:53:11 -0500 Subject: [PATCH 8/8] fix lint --- python/tvm/dlight/gpu/low_batch_gemv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index c55a3a86d0db..dfed020853e9 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -31,7 +31,6 @@ try_inline_contiguous_spatial, ) from .base import GPUScheduleRule -from .matmul import auto_inline_producers def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: