Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,18 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
plevel=12,
)

if (
target.features.is_aarch64
and data.dtype in ["float16", "float32"]
and weight.dtype in ["float16", "float32"]
and out_type.dtype in ["float16", "float32"]
):
strategy.add_implementation(
wrap_compute_dense(topi.arm_cpu.dense_gemm),
wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
name="dense_gemm.arm_cpu",
plevel=11,
)
# Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_nopack),
Expand Down Expand Up @@ -773,6 +785,19 @@ def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
lambda: None,
name="matmul.arm_cpu.sme",
)
elif (
target.features.is_aarch64
and data.dtype in ["float16", "float32"]
and weight.dtype in ["float16", "float32"]
and out_type.dtype in ["float16", "float32"]
and not (attrs.transpose_a or attrs.transpose_b)
and len(data.shape) == 2
):
strategy.add_implementation(
wrap_compute_matmul(topi.arm_cpu.dense_gemm),
wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
name="matmul.arm_cpu.neon",
)
return strategy

logger.warning("matmul is not optimized for arm cpu.")
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,11 @@ def _multi_gpu_exists():
"x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64"
)

# Mark a test as requiring the AArch64 architecture to run.
# The check is purely runtime: it inspects the host machine string reported
# by `platform.machine()`, so cross-compiled test environments will not match.
requires_aarch64 = Feature(
    "AArch64", "AArch64 Architecture", run_time_check=lambda: platform.machine() == "aarch64"
)

# Mark a test as requiring the CUDA runtime.
requires_cuda = Feature(
"cuda",
Expand Down
21 changes: 15 additions & 6 deletions python/tvm/topi/arm_cpu/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,29 @@
# under the License.
"""Dense schedule for ARM CPU"""
from tvm import autotvm

from .mprofile.dsp.dense import (
dense_dsp_schedule,
dense_dsp_compute,
)
from .mprofile.dsp.dense import dense_dsp_schedule, dense_dsp_compute
from .dense_gemm import dense_gemm_compute, dense_gemm_schedule


@autotvm.register_topi_compute("dense_dsp.arm_cpu")
def dense_dsp(cfg, data, weight, bias, out_dtype):
    """Dense compute lowered to DSP instructions.

    Autotvm-registered entry point; the actual tensor-expression
    construction lives in ``dense_dsp_compute``.
    """
    result = dense_dsp_compute(cfg, data, weight, bias=bias, out_dtype=out_dtype)
    return result


@autotvm.register_topi_schedule("dense_dsp.arm_cpu")
def schedule_dense_dsp(cfg, outs):
    """Build the schedule for the DSP dense compute.

    Autotvm-registered entry point; delegates to ``dense_dsp_schedule``.
    """
    schedule = dense_dsp_schedule(cfg, outs)
    return schedule


@autotvm.register_topi_compute("dense_gemm.arm_cpu")
def dense_gemm(cfg, data, weight, bias, out_dtype, transpose_a=False, transpose_b=True):
    """Dense compute expressed as a GeMM.

    Autotvm-registered entry point; forwards all arguments to
    ``dense_gemm_compute`` unchanged.
    """
    return dense_gemm_compute(
        cfg,
        data,
        weight,
        bias=bias,
        out_dtype=out_dtype,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
    )


@autotvm.register_topi_schedule("dense_gemm.arm_cpu")
def schedule_dense_gemm(cfg, outs):
    """Build the schedule for the GeMM dense compute.

    Autotvm-registered entry point; delegates to ``dense_gemm_schedule``.
    """
    schedule = dense_gemm_schedule(cfg, outs)
    return schedule
34 changes: 27 additions & 7 deletions python/tvm/topi/arm_cpu/dense_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Dense alter op definitions for the `arm_cpu` device key."""

import tvm
Expand Down Expand Up @@ -47,13 +48,11 @@ def _alter_dense(attrs, inputs, tinfos, out_type):

cfg = dispatch_ctx.query(target, workload)
topi_impl = workload[0]

if topi_impl == "matmul.arm_cpu.sme":
# Pre-compute transposed weights and convert to a matmul
assert isinstance(
inputs[1], relay.Constant
), "matmul_sme.arm_cpu requires weights be a Relay Constant"

weight_dtype = tinfos[1].dtype
N, K = tinfos[1].shape
encoded_weight = inputs[1]

# For dense the weights (rhs) are provided in transposed format,
Expand All @@ -65,15 +64,15 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
# float16->float32 schedule the transformation currently happens at runtime
# with the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic.
if weight_dtype == "float32":
encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype)
encoded_weight = relay.transpose(encoded_weight)
transpose_b = False

new_weight = te.placeholder((encoded_weight.data.shape), dtype=weight_dtype)
new_weight = te.placeholder(([K, N]), dtype=weight_dtype)

new_workload = autotvm.task.args_to_workload(
[tinfos[0], new_weight, None, out_type.dtype, False, transpose_b], topi_impl
)
dispatch_ctx.update(target, new_workload, cfg)

return _make.matmul(
inputs[0],
encoded_weight,
Expand All @@ -82,6 +81,27 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
False,
transpose_b,
)
elif topi_impl == "dense_gemm.arm_cpu":

weight_dtype = tinfos[1].dtype
N, K = tinfos[1].shape

encoded_weight = relay.transpose(inputs[1])
new_weight = te.placeholder(([K, N]), dtype=weight_dtype)

new_workload = autotvm.task.args_to_workload(
[tinfos[0], new_weight, None, out_type.dtype, False, False], topi_impl
)
dispatch_ctx.update(target, new_workload, cfg)

return _make.matmul(
inputs[0],
encoded_weight,
attrs.units,
attrs.out_dtype,
False,
False,
)

# x86 schedules are used as a fallback
return tvm.topi.x86.dense_alter_op._alter_dense_layout(attrs, inputs, tinfos, out_type)
174 changes: 174 additions & 0 deletions python/tvm/topi/arm_cpu/dense_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""GeMM dense schedule on AArch64"""
import tvm
from tvm import te
from tvm.topi import nn
from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed, pad_dim_to_multiple
from ..utils import get_const_tuple, traverse_inline
from .. import tag

# Compute function
def dense_gemm_compute(
    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, transpose_b=True
):
    """
    Compute dense using GeMM.

    Pads M, K and N up to the tiling multiples required by the AArch64
    GeMM schedule, brings the weights into a [K, N] layout when they
    arrive transposed, performs the matrix multiplication (with an
    optional bias add), and finally crops the result back to the
    unpadded [M, N] shape.

    Parameters
    ----------
    cfg : Autotvm tuning space config file,
        empty in this case, but it's needed as an arg.

    data : tvm.te.Tensor
        2-D with shape [M, K] or [K, M].

    weight : tvm.te.Tensor
        2-D with shape [K, N] or [N, K].

    bias : Optional[tvm.te.Tensor]
        1-D with shape [N]

    out_dtype : Optional[str]
        Specifies the output data type. Defaults to ``data.dtype``.

    transpose_a : Optional[bool] = False
        Whether the data tensor is in transposed format.
        NOTE(review): the body never branches on this flag — data is
        always read as [M, K]; confirm callers only pass False.

    transpose_b : Optional[bool] = True
        Whether the weight tensor is in transposed format.

    Returns
    -------
    out : tvm.te.Tensor
        2-D with shape [M, N].
    """

    if out_dtype is None:
        out_dtype = data.dtype
    M, K = get_const_tuple(data.shape)  # batch, in_dim
    if bool(transpose_b):  # out_dim
        (N, _) = get_const_tuple(weight.shape)
    else:
        (_, N) = get_const_tuple(weight.shape)

    # Tile sizes the schedule will split by; every padded dimension must be
    # an exact multiple of its tile so the split loops have no remainder.
    tile_M, tile_K = get_tiling_A(False, out_dtype)
    tile_N, _ = get_tiling_B_transformed(False, out_dtype, False)

    M_padded, pad_M = pad_dim_to_multiple(M, tile_M)
    K_padded, pad_K = pad_dim_to_multiple(K, tile_K)
    N_padded, pad_N = pad_dim_to_multiple(N, tile_N)
    m_pad_after = (pad_M, pad_K)
    # Weight padding order depends on whether it is still [N, K] or already [K, N].
    n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N)

    if pad_M != 0 or pad_K != 0:
        data = nn.pad(data, pad_before=(0, 0), pad_after=m_pad_after, name="data_padded")

    k = te.reduce_axis((0, K_padded), name="k")

    # Materialise the weights in [K, N] layout so the inner matmul below can
    # always index weight[k, y].
    if bool(transpose_b):
        weight = te.compute(
            (K_padded, N_padded), lambda x, y: weight[y, x], name="weight_transposed"
        )

    if pad_N != 0 or pad_K != 0:
        weight = nn.pad(weight, pad_before=(0, 0), pad_after=n_pad_after, name="weight_padded")

    # Padded matmul: C[x, y] = sum_k data[x, k] * weight[k, y].
    C = te.compute(
        (M_padded, N_padded),
        lambda x, y: te.sum(
            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
            axis=k,
        ).astype(out_dtype),
        name="C",
    )

    if bias is not None:
        C = te.compute(
            (M_padded, N_padded),
            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
            tag=tag.BROADCAST,
            name="dense_biased_output",
        )

    # We need to ensure that infer bound pass does not remove the padding
    # which is necessary for the tensorizations to work. So we need to
    # add a dummy reference to the padding area of the result
    # (the expression is always numerically zero).
    zero = (
        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
    )

    # Crop back to the caller-visible, unpadded shape.
    out = te.compute(
        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), name="dense_gemm_output"
    )

    return out


def _dense_gemm_schedule(s, out):
    """Tile, reorder, unroll and vectorise the GeMM matmul stage feeding `out`.

    `out` is the `dense_gemm_output` crop stage produced by
    `dense_gemm_compute`; its first input is the (optionally biased)
    matmul stage. Mutates the schedule `s` in place and returns it.
    """
    C = out.op.input_tensors[0]
    A = C.op.input_tensors[0]
    # Tiling factors are chosen from the accumulator dtype, mirroring the
    # choices made in dense_gemm_compute so the padded extents divide evenly.
    out_type = A.dtype
    tile_M, tile_K = get_tiling_A(False, out_type)
    tile_N, _ = get_tiling_B_transformed(False, out_type, False)

    # Inline the bias add (if present) so the scheduling below is applied
    # to the underlying matmul stage.
    if C.op.name == "dense_biased_output":
        s[C].compute_inline()
        C = C.op.input_tensors[0]
    x, y = s[C].op.axis
    (k,) = s[C].op.reduce_axis

    k_outer, k_inner = s[C].split(k, factor=tile_K)
    x_outer, x_inner = s[C].split(x, factor=tile_M)
    y_outer, y_inner = s[C].split(y, factor=tile_N)
    y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
    # Parallelise over rows of output tiles; keep reduction outside the
    # unrolled/vectorised innermost micro-kernel.
    s[C].parallel(x_outer)
    s[C].reorder(
        x_outer,
        y_outer,
        k_outer,
        k_inner,
        y_inner_outer,
        x_inner,
        y_inner_inner,
    )
    s[C].unroll(y_inner_outer)
    s[C].unroll(x_inner)
    s[C].vectorize(y_inner_inner)

    return s


def dense_gemm_schedule(cfg, outs):
    """Schedule the dense_gemm strategy

    Creates the schedule over `outs`, parallelises/vectorises the final
    output stage, and applies the GeMM tiling (`_dense_gemm_schedule`)
    to every `dense_gemm_output` op found in the dataflow graph.
    """
    s = te.create_schedule([x.op for x in outs])
    out = outs[0]
    x, y = out.op.axis
    # Vectorise the innermost chunk of the output columns and parallelise rows.
    _, inner = s[out].split(y, 4)
    s[out].parallel(x)
    s[out].vectorize(inner)

    def _callback(op):
        # Apply the GeMM tiling to the crop stage emitted by dense_gemm_compute.
        if "dense_gemm_output" in op.name:
            _dense_gemm_schedule(s, op.output(0))

    traverse_inline(s, out.op, _callback)
    return s
2 changes: 2 additions & 0 deletions python/tvm/topi/nn/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def matmul(
assert (
len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
), "1-dim matmul is not supported yet."

if bias is not None:
assert len(bias.shape) == 1
if out_dtype is None:
Expand Down Expand Up @@ -229,6 +230,7 @@ def dense(
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""

return matmul(
data,
weight,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_keras_output(in_data):
def get_tvm_output(in_data, target, dev, dtype="float32"):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, in_data)}
mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
with tvm.transform.PassContext(opt_level=2):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target, params=params)
m = graph_executor.GraphModule(lib["default"](dev))
for name, x in zip(keras_model.input_names, in_data):
Expand Down
Loading