From f6a3d33030a7116e9a0d9036dcf6d6804cc2d8e2 Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Fri, 4 Aug 2023 14:26:58 +0800
Subject: [PATCH] [Unity][DLight] Use less shared memory for gemv

This PR fixes an issue where the GEMV rule used too much shared memory
on the llama-70B model: the reduction loop is now split by a `split_k`
factor, so only a fraction of the cached vector has to stay in shared
memory at any one time.
---
 python/tvm/dlight/gpu/gemv.py        | 16 +++---
 tests/python/dlight/test_gpu_gemv.py | 80 ++++++++++++++--------
 2 files changed, 51 insertions(+), 45 deletions(-)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 4c11aa778057..13dee1cd54e9 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -219,19 +219,23 @@ def get_extent(loop_rv: tir.schedule.LoopRV):
     else:
         len_ty = min(len_s, 4)
 
+    # Use `split_k` to prevent too large shared memory usage
+    split_k: int = 4
+
     _, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True)
     # Schedule the RF block
     rf = sch.rfactor(tx, 0)
     batch, bx, r, tx, _ = sch.get_loops(rf)
     sch.reorder(bx, tx, r)
+    ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True)
     bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True)
+
     sch.bind(batch, "blockIdx.y")
     sch.bind(bx, "blockIdx.x")
     sch.bind(ty, "threadIdx.y")
     sch.bind(tx, "threadIdx.x")
-    unit = sch.add_unit_loop(r)
-    sch.annotate(unit, "pragma_auto_unroll_max_step", unroll_number)
-    sch.annotate(unit, "pragma_unroll_explicit", 1)
+    sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number)
+    sch.annotate(ro, "pragma_unroll_explicit", 1)
 
     if target.kind.name == "cuda":
         # Cache read the vector
@@ -239,7 +243,7 @@ def cache_shared(index: int):
             block: tir.Block = sch.get(rf)
             type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
             cache = sch.cache_read(rf, index, "shared")
-            sch.compute_at(cache, unit, preserve_unit_loops=True)
+            sch.compute_at(cache, ro, preserve_unit_loops=True)
             fused = sch.fuse(*sch.get_loops(cache)[5:])
             loop: tir.For = sch.get(fused)
             vec_length = vec_bytes // type_bytes
@@ -256,7 +260,7 @@ def cache_local(index: int):
             type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
             vec_length = vec_bytes // type_bytes
             cache = sch.cache_read(rf, index, "local")
-            sch.compute_at(cache, r, preserve_unit_loops=True)
+            sch.compute_at(cache, ri, preserve_unit_loops=True)
             fused = sch.fuse(*sch.get_loops(cache)[6:])
             loop: tir.For = sch.get(fused)
             if isinstance(loop.extent, tir.IntImm) and loop.extent.value % vec_length == 0:
@@ -273,7 +277,7 @@ def cache_local(index: int):
 
         # TODO: cache scale buffer in Decode-GEMV to shared memory
         sch.set_scope(rf, 0, "local")
-        sch.decompose_reduction(rf, r)
+        sch.decompose_reduction(rf, ro)
         # Schedule the write back block
         sch.reverse_compute_at(block, ty, preserve_unit_loops=True)
         _, _, _, tx, *s = sch.get_loops(block)
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index fb6315f8021c..6cb5cceb4320 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -97,40 +97,40 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
             for ax1_fused_0 in T.thread_binding(n, thread="blockIdx.x"):
                 for ax1_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
                     for ax2_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
-                        for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        with T.block("NT_matmul_rf_init"):
+                            vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
+                            v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
+                            T.reads()
+                            T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
+                            var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
+                        for ax2_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                             for ax0_ax1_ax2_ax3_fused_0 in range(1):
                                 for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
                                     for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(4):
+                                        for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(1):
                                             with T.block("lv1637_shared"):
                                                 v0 = T.axis.spatial(1, 0)
                                                 v1 = T.axis.spatial(32, ax0_fused)
                                                 v2 = T.axis.spatial(1, 0)
-                                                v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 128 + ax0_ax1_ax2_ax3_fused_2 * 4 + ax0_ax1_ax2_ax3_fused_3)
+                                                v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax0_ax1_ax2_ax3_fused_0 * 32 + ax0_ax1_ax2_ax3_fused_1 * 32 + ax0_ax1_ax2_ax3_fused_2 + ax0_ax1_ax2_ax3_fused_3)
                                                 T.reads(lv1637[v0, v1, v2, v3])
                                                 T.writes(lv1637_shared[v0, v1, v2, v3])
                                                 lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3]
-                            with T.block("NT_matmul_rf_init"):
-                                vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
-                                v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
-                                T.reads()
-                                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
-                                var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
-                            for ax2_fused_0 in range(4):
+                            for ax2_fused_0_1 in range(1):
                                 for ax0_ax1_ax2_ax3_fused in T.vectorized(1):
                                     with T.block("lv1637_shared_local"):
                                         v0 = T.axis.spatial(1, 0)
                                         v1 = T.axis.spatial(32, ax0_fused)
                                         v2 = T.axis.spatial(1, 0)
-                                        v3 = T.axis.spatial(128, ax2_fused_0 * 32 + ax2_fused_1)
+                                        v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax2_fused_1)
                                         T.reads(lv1637_shared[v0, v1, v2, v3])
                                         T.writes(lv1637_shared_local[v0, v1, v2, v3])
                                         lv1637_shared_local[v0, v1, v2, v3] = lv1637_shared[v0, v1, v2, v3]
-                                for u_1 in range(1):
+                                for u in range(1):
                                     with T.block("NT_matmul_rf_update"):
                                         vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
                                         v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
-                                        vax2_fused_0 = T.axis.reduce(4, ax2_fused_0)
+                                        vax2_fused_0 = T.axis.reduce(4, ax2_fused_0_0 + ax2_fused_0_1)
                                         T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1], lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1], lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1])
                                         T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
                                         var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] + lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1] * lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1]
@@ -186,31 +186,31 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
             for ax0_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
                 for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                     for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
-                        for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_ax1_ax2_fused_0 in range(2):
+                        with T.block("NT_matmul_rf_init"):
+                            vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
+                            v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
+                            T.reads()
+                            T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
+                        for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_ax1_ax2_fused_0 in range(1):
                                 for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                                     for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
-                                        for ax0_ax1_ax2_fused_3 in T.vectorized(8):
+                                        for ax0_ax1_ax2_fused_3 in T.vectorized(4):
                                             with T.block("lv1654_shared"):
                                                 v0 = T.axis.spatial(1, 0)
                                                 v1 = T.axis.spatial(1, 0)
-                                                v2 = T.axis.spatial(4096, ax0_ax1_ax2_fused_0 * 2048 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 8 + ax0_ax1_ax2_fused_3)
+                                                v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
                                                 T.reads(lv1654[v0, v1, v2])
                                                 T.writes(lv1654_shared[v0, v1, v2])
                                                 lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2]
-                            with T.block("NT_matmul_rf_init"):
-                                vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
-                                v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
-                                T.reads()
-                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
-                            for ax1_0_fused_0 in range(16):
+                            for ax1_0_fused_0_1 in range(4):
                                 for ax0_ax1_ax2_fused_0 in range(1):
                                     for ax0_ax1_ax2_fused_1 in T.vectorized(8):
                                         with T.block("lv1654_shared_local"):
                                             v0 = T.axis.spatial(1, 0)
                                             v1 = T.axis.spatial(1, 0)
-                                            v2 = T.axis.spatial(4096, ax1_0_fused_0 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
+                                            v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
                                             T.reads(lv1654_shared[v0, v1, v2])
                                             T.writes(lv1654_shared_local[v0, v1, v2])
                                             lv1654_shared_local[v0, v1, v2] = lv1654_shared[v0, v1, v2]
@@ -218,7 +218,8 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
                                     with T.block("NT_matmul_rf_update"):
                                         vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
                                         v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
-                                        vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1])
+                                        vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
+                                        vax1_1 = T.axis.reduce(8, ax1_1)
                                         T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
                                         T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
                                         var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
@@ -278,31 +279,31 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12
             for ax0_fused_0 in T.thread_binding(4000, thread="blockIdx.x"):
                 for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
                     for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
-                        for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-                            for ax0_ax1_ax2_fused_0 in range(2):
+                        with T.block("NT_matmul_rf_init"):
T.block("NT_matmul_rf_init"): + vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) + v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) + for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_ax1_ax2_fused_0 in range(1): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"): for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_3 in T.vectorized(8): + for ax0_ax1_ax2_fused_3 in T.vectorized(4): with T.block("lv3216_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(4096, ax0_ax1_ax2_fused_0 * 2048 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 8 + ax0_ax1_ax2_fused_3) + v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3) T.reads(lv3216[v0, v1, v2]) T.writes(lv3216_shared[v0, v1, v2]) lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2] - with T.block("NT_matmul_rf_init"): - vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) - v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) - T.reads() - T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) - for ax1_0_fused_0 in range(16): + for ax1_0_fused_0_1 in range(4): for ax0_ax1_ax2_fused_0 in range(1): for ax0_ax1_ax2_fused_1 in T.vectorized(8): with T.block("lv3216_shared_local"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(4096, ax1_0_fused_0 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1) + v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1) T.reads(lv3216_shared[v0, v1, v2]) T.writes(lv3216_shared_local[v0, v1, v2]) lv3216_shared_local[v0, v1, v2] = lv3216_shared[v0, v1, v2] @@ -310,7 +311,8 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12 with T.block("NT_matmul_rf_update"): vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) - vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1]) + vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1) + vax1_1 = T.axis.reduce(8, ax1_1) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])