Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions python/tvm/topi/arm_cpu/conv2d_spatial_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.target import Target
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity, AnnotateEntity, ReorderEntity
from .. import nn
from ..utils import get_const_tuple
Expand Down Expand Up @@ -316,12 +317,23 @@ def _tile_size(axis, candidates):
return candidate
return 1

# Tile size 8 results in efficient vectorization for these schedules.
# If the axis is not divisible by 8, try 4
# For data tensors with unity height and width we can leave it to the
# backend to vectorize the inner loop. This has been observed to be more
# performant on SVE targets with a vector width > 128bits.
target = Target.current(allow_none=False)
if target.features.has_sve and OW == OH and OW == 1:
Comment thread
FranklandJack marked this conversation as resolved.
tile_size = [OC]
vectorize = "none"
else:
# Tile size 8 results in efficient vectorization for these schedules.
# If the axis is not divisible by 8, try 4
tile_size = [8, 4]
vectorize = "vec"

cfg["tile_oh"] = SplitEntity([-1, 1])
cfg["tile_ow"] = SplitEntity([-1, _tile_size(OW, [8, 4])])
cfg["tile_co"] = SplitEntity([-1, _tile_size(OC, [8, 4])])
cfg["ann_spatial"] = AnnotateEntity(["none", "vec"])
cfg["tile_co"] = SplitEntity([-1, _tile_size(OC, tile_size)])
cfg["ann_spatial"] = AnnotateEntity(["none", vectorize])
cfg["ann_reduce"] = AnnotateEntity(["none", "none"])
cfg["reorder_conv"] = ReorderEntity([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
cfg["compat"] = OtherOptionEntity(0)
Expand Down
35 changes: 18 additions & 17 deletions src/target/parsers/aprofile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,29 +106,30 @@ static TargetFeatures GetFeatures(TargetJSON target) {
Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
Optional<Array<String>> mattr = Downcast<Optional<Array<String>>>(target.Get("mattr"));

double arch_version = GetArchVersion(mattr);
const double arch_version = GetArchVersion(mattr);

bool is_aarch64 = IsAArch64(mtriple);
const bool is_aarch64 = IsAArch64(mtriple);

bool simd_flag = HasFlag(mcpu, mattr, "+neon") || HasFlag(mcpu, mattr, "+simd");
bool has_asimd = is_aarch64 || simd_flag;
const bool simd_flag = HasFlag(mcpu, mattr, "+neon") || HasFlag(mcpu, mattr, "+simd");
const bool has_asimd = is_aarch64 || simd_flag;
const bool has_sve = HasFlag(mcpu, mattr, "+sve");

bool i8mm_flag = HasFlag(mcpu, mattr, "+i8mm");
bool i8mm_disable = HasFlag(mcpu, mattr, "+noi8mm");
bool i8mm_default = arch_version >= 8.6;
bool i8mm_support = arch_version >= 8.2 && arch_version <= 8.5;
bool has_i8mm = (i8mm_default && !i8mm_disable) || (i8mm_support && i8mm_flag);
const bool i8mm_flag = HasFlag(mcpu, mattr, "+i8mm");
const bool i8mm_disable = HasFlag(mcpu, mattr, "+noi8mm");
const bool i8mm_default = arch_version >= 8.6;
const bool i8mm_support = arch_version >= 8.2 && arch_version <= 8.5;
const bool has_i8mm = (i8mm_default && !i8mm_disable) || (i8mm_support && i8mm_flag);

bool dotprod_flag = HasFlag(mcpu, mattr, "+dotprod");
bool dotprod_disable = HasFlag(mcpu, mattr, "+nodotprod");
bool dotprod_default = arch_version >= 8.4;
bool dotprod_support = arch_version >= 8.2 && arch_version <= 8.3;
bool has_dotprod = (dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);
const bool dotprod_flag = HasFlag(mcpu, mattr, "+dotprod");
const bool dotprod_disable = HasFlag(mcpu, mattr, "+nodotprod");
const bool dotprod_default = arch_version >= 8.4;
const bool dotprod_support = arch_version >= 8.2 && arch_version <= 8.3;
const bool has_dotprod =
(dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);

return {
{"is_aarch64", Bool(is_aarch64)},
{"has_asimd", Bool(has_asimd)},
{"has_dotprod", Bool(has_dotprod)},
{"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
{"has_sve", Bool(has_sve)}, {"has_dotprod", Bool(has_dotprod)},
Comment thread
FranklandJack marked this conversation as resolved.
{"has_matmul_i8", Bool(has_i8mm)},
};
}
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/target/parsers/aprofile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,28 @@ TEST(AProfileParser, ArchVersionInvalidLetter) {
ASSERT_EQ(Downcast<Bool>(features.at("has_dotprod")), false);
}

// Suite parameterised over the architecture version that is spliced into the
// "+v<version>a" attribute string.
using AProfileOptionalSVE = testing::TestWithParam<float>;
TEST_P(AProfileOptionalSVE, OptionalSVESupport) {
  // Build the architecture attribute from the current parameter.
  std::string arch_attr{"+v"};
  arch_attr += std::to_string(GetParam());
  arch_attr += "a";

  // Case 1: only the architecture attribute is given, so "has_sve" must be off.
  TargetJSON plain_target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
  TargetFeatures plain_features = Downcast<TargetFeatures>(plain_target.at("features"));
  EXPECT_TRUE(IsArch(plain_target));
  EXPECT_FALSE(Downcast<Bool>(plain_features.at("has_sve")));

  // Case 2: "+sve" is passed explicitly alongside the architecture attribute,
  // so "has_sve" must be on.
  TargetJSON sve_target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+sve"});
  TargetFeatures sve_features = Downcast<TargetFeatures>(sve_target.at("features"));
  EXPECT_TRUE(IsArch(sve_target));
  EXPECT_TRUE(Downcast<Bool>(sve_features.at("has_sve")));
}

// Bind each parameterised suite to its parameter values. optionalI8MM and
// optionalDotProd are defined outside this excerpt.
INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalI8MM, ::testing::ValuesIn(optionalI8MM));
INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalDotProd,
::testing::ValuesIn(optionalDotProd));
// The SVE suite runs once per listed version (8.0 through 9.0); each value is
// turned into a "+v<version>a" attribute by the test body.
INSTANTIATE_TEST_CASE_P(AProfileParser, AProfileOptionalSVE,
::testing::Values(8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9.0));

} // namespace aprofile
} // namespace parsers
Expand Down