From 298052a52d24b28ba1b0d8901f1493ff32a85280 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Mon, 10 Jun 2024 08:46:10 +0000
Subject: [PATCH] [SME] Extract gemm block correctly when fused with
 bias/activation

Prior to this commit, the scheduling assumed the gemm block would be the
second-to-last block in the function (the "unpadding" step being the final
block). However, when dense is fused with a bias or activation, the gemm
block is no longer the second-to-last block. This commit instead searches
for a single reduction block and uses it as the gemm block.

Change-Id: I1932a490bb3fb72c0c081862349486838c15e6de
---
 python/tvm/topi/arm_cpu/matmul.py              |  8 +++---
 .../codegen/test_target_codegen_aarch64.py     | 15 +++++++++++
 .../relay/strategy/arm_cpu/test_dense.py       | 26 ++++++++++++-------
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
index 2f09e24c87a2..23b8734a0ba4 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -26,6 +26,7 @@
 from tvm.topi.utils import get_const_tuple
 from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
 from tvm.topi.arm_cpu.arm_utils import pad_dim_to_multiple
+from tvm.dlight.base.analysis import normalize_prim_func
 
 
 @autotvm.register_topi_compute("matmul.arm_cpu.sme")
@@ -126,9 +127,10 @@ def tir_schedule_matmul_sme(sch):
     in_dtype = main_func.buffer_map[data_handle].dtype
     out_dtype = "float32"
 
-    root_block = sch.get_block(main_func.body.block.name_hint)
-    gemm_block = sch.get_child_blocks(root_block)[-2]
-
+    block_infos = normalize_prim_func(sch)
+    reduction_block_infos = [block_info for block_info in block_infos if block_info.is_reduction()]
+    assert len(reduction_block_infos) == 1, "Expected a single gemm reduction block."
+    gemm_block = reduction_block_infos[0].block_rv
     gemm_block_name = sch.get(gemm_block).name_hint
     transpose = gemm_block_name.split("_")[-1]
     transpose_b = transpose[1] == "T"
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 77c22761a9c8..9b0408b949a0 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -540,6 +540,21 @@ def check_correct_assembly(dtype):
     check_correct_assembly(dtype=dtype)
 
 
+def test_matmul_sme_no_reduction_block():
+    @T.prim_func
+    def prim_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (4,))
+        B = T.match_buffer(b, (4,))
+        for i in range(3):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    sch = tvm.tir.Schedule(prim_func)
+    with pytest.raises(AssertionError, match="Expected a single gemm reduction block."):
+        tvm.topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
 )
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py
index 3a8427e8154d..fee8a87f1253 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -99,16 +99,16 @@ class TestDense(BasicDenseTests):
 )
 @tvm.testing.requires_aprofile_aem_fvp
 @pytest.mark.parametrize(
-    "data_shape,weight_shape",
+    "data_shape,weight_shape,enable_bias",
     [
-        ((32, 32), (32, 32)),
-        ((2, 35), (6, 35)),
-        ((3, 3), (68, 3)),
-        ((79, 65), (152, 65)),
+        ((32, 32), (32, 32), False),
+        ((2, 35), (6, 35), False),
+        ((3, 3), (68, 3), False),
+        ((79, 65), (152, 65), True),
     ],
 )
 @pytest.mark.parametrize("in_dtype", ["float32", "float16"])
-def test_sme_dense(data_shape, weight_shape, in_dtype):
+def test_sme_dense(data_shape, weight_shape, enable_bias, in_dtype):
     np.random.seed(0)
 
     out_dtype = "float32"
@@ -117,8 +117,14 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
     weight_data = np.random.uniform(size=weight_shape).astype(in_dtype)
     weight = relay.const(weight_data, dtype=in_dtype)
 
-    dense = relay.nn.dense(inp, weight, out_dtype=out_dtype)
-    func = relay.Function(relay.analysis.free_vars(dense), dense)
+    relay_op = relay.nn.dense(inp, weight, out_dtype=out_dtype)
+
+    if enable_bias:
+        bias_data = np.random.uniform(size=weight_shape[0]).astype(out_dtype)
+        bias = relay.const(bias_data, dtype=out_dtype)
+        relay_op = relay.nn.bias_add(relay_op, bias)
+
+    func = relay.Function(relay.analysis.free_vars(relay_op), relay_op)
 
     ir_mod = tvm.IRModule.from_expr(func)
     ir_mod = tvm.relay.transform.InferType()(ir_mod)
@@ -147,8 +153,10 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):
             runtime=runtime,
             params=params,
         )
+
+    bias_postfix = "_add" if enable_bias else ""
     generated_func = executor_factory.lowered_ir_mods.items()[0][1][
-        "tvmgen_default_fused_nn_matmul"
+        f"tvmgen_default_fused_nn_matmul{bias_postfix}"
    ]
     extra_memory_in_bytes = calculate_extra_workspace_size_from_scalable_extents(generated_func, 4)
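
Note (not part of the patch): a minimal, self-contained sketch of the reduction-block lookup
this change switches to, using the same normalize_prim_func helper the patch imports. The
matmul PrimFunc below is an illustrative stand-in for the TIR produced by the SME dense/matmul
compute, not code taken from the TVM codebase.

    import tvm
    from tvm.script import tir as T
    from tvm.dlight.base.analysis import normalize_prim_func


    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle):
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        C = T.match_buffer(c, (16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


    sch = tvm.tir.Schedule(matmul)

    # Collect block infos and keep only reduction blocks, mirroring the patched
    # tir_schedule_matmul_sme: the gemm block is expected to be the single
    # reduction block, regardless of how many fused elementwise blocks follow it.
    block_infos = normalize_prim_func(sch)
    assert block_infos is not None
    reduction_blocks = [info for info in block_infos if info.is_reduction()]
    assert len(reduction_blocks) == 1
    gemm_block = reduction_blocks[0].block_rv
    print(sch.get(gemm_block).name_hint)  # expected: "matmul"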