diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 602a198a2bf6..10c7676282e7 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -231,17 +231,18 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // Current analysis may not be powerful enough to prove expressions containing // the same symbolic value multiple times. However, when the symbolic values are // "T.vscale" and the compile target uses a scalable architecture extension like - // SVE, we can make some assumptions about the value of vscale and iterate over a + // VLA, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. Target curr_target = Target::Current(); if (ContainsVscaleCall(simplified)) { - if (TargetHasSVE(curr_target)) { - return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); + if (TargetHasVLA(curr_target)) { + auto kVScaleValues = GetVScaleValues(curr_target); + return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " - "AArch64 SVE targets, but the target was " + "VLA targets, but the target was " << curr_target; } return false; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index ac8ac917114a..7409ecc6f37e 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl // only special handle >> and & which can be // used for index calculation. 
+ auto curr_target = Target::Current(); if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); } else if (op->op.same_as(tir::builtin::shift_left())) { return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) { - unsigned int max_val = - *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end()); + } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) { + auto kVScaleValues = GetVScaleValues(curr_target); + unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); return MakeBound(1, max_val); } else { return Everything(op->dtype); diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index beb75c1f3e09..1937b9c34e03 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -86,14 +86,41 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasSVE(Optional<Target> target) { +bool TargetHasVLA(Optional<Target> target) { if (!target.defined()) { target = Target::Current(); } + bool has_vla{false}; if (target.defined()) { - return Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false)); + // aarch64 + has_vla = Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false)); + // riscv{32,64} + static auto target_has_feature_fn = + tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); + has_vla |= target_has_feature_fn("v", target).cast<bool>(); } - return false; + return has_vla; +} + +const std::vector<unsigned int> GetVScaleValues(Optional<Target> target) { + unsigned int vector_width = 0; + std::vector<unsigned int> kVScaleValues; + if (!target.defined()) { + target = Target::Current(); + } + if (target.defined()) { + static auto llvm_get_vector_width_fn = + tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width"); + vector_width = 
llvm_get_vector_width_fn(target).cast<unsigned int>(); + } + // scale list with powers of two + for (unsigned int i = 0;; ++i) { + auto power = static_cast<unsigned int>(std::pow(2, i)); + if (power > (vector_width / 8)) break; + kVScaleValues.push_back(power); + } + + return kVScaleValues; } } // namespace arith diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index 70d753a299ef..2470d5dcd827 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -35,9 +35,6 @@ namespace tvm { namespace arith { -/*! \brief A list of known vscale values to try for an AArch64 SVE target. */ -static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 4, 8, 16}; - /*! * \brief Check if an expr is a call to the vscale intrinsic. * \param expr The expr to check @@ -80,10 +77,18 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr /*! * \brief Check whether the compilation target supports SVE + * \brief Check whether the compilation target supports VLA + * \param target The target to check. + * \return Whether VLA is supported + */ +bool TargetHasVLA(Optional<Target> target = std::nullopt); + +/*! + * \brief Get a list of known vscale values to try for a VLA target. * \param target The target to check. - * \return Whether SVE is supported + * \return A list of vscale values as std::vector<unsigned int> */ -bool TargetHasSVE(Optional<Target> target = std::nullopt); +const std::vector<unsigned int> GetVScaleValues(Optional<Target> target = std::nullopt); } // namespace arith } // namespace tvm diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 1399fc083a08..b690c0fc28b1 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) { #if TVM_LLVM_VERSION >= 130 // Add vscale_range() function attribute when appropriate. 
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) { - unsigned int max_val = - *std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end()); + auto kVScaleValues = arith::GetVScaleValues(Target::Current()); + unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); func->addFnAttr( llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 3df73f0edb8d..54b2daf83632 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -80,8 +80,8 @@ bool EnableBufferLevelPredication(Target target) { return enable_buffer_predication.value(); } - // Use buffer-level predication by default for AArch64 SVE targets - return arith::TargetHasSVE(target); + // Use buffer-level predication by default for VLA targets + return arith::TargetHasVLA(target); } /*! @@ -972,7 +972,7 @@ class LoopVectorizer : public StmtMutator { if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && arith::TargetHasSVE(target_)) + ICHECK(is_scalable_expr && arith::TargetHasVLA(target_)) << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } ICHECK(is_zero(op->min)); diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 3b0237740045..4971acbd4512 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -113,7 +113,7 @@ def test_simplify_vscale_comparison_without_sve_target(capfd): warning_msg = ( "Warning: The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. 
This proof currently only supports " - "AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu" + "VLA targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu" ) capture = capfd.readouterr().err assert warning_msg in capture diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 43870044d528..2c8f185d8ecd 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -43,7 +43,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] * B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mul instructions using z registers assembly = f.get_source("asm") @@ -73,7 +75,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] + B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and add instructions using z registers assembly = f.get_source("asm") @@ -103,7 +107,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] - B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and sub instructions using z registers assembly = f.get_source("asm") @@ -134,7 +140,9 @@ def check_correct_assembly(type): B = 
te.placeholder(m, dtype=type, name="B") C = te.placeholder(m, dtype=type, name="C") D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D") - f = tvm.tir.build(te.create_prim_func([A, B, C, D]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C, D])) # Verify we see SVE load instructions and either mad or mla instructions using z registers assembly = f.get_source("asm") @@ -164,7 +172,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.max(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers assembly = f.get_source("asm") @@ -198,7 +208,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.min(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers assembly = f.get_source("asm") @@ -232,7 +244,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.div(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and div instructions using z registers assembly = f.get_source("asm") @@ -261,7 +275,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, 
dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mls instructions using z registers assembly = f.get_source("asm") @@ -291,7 +307,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] == B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers assembly = f.get_source("asm") @@ -321,7 +339,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] != B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers assembly = f.get_source("asm") @@ -350,7 +370,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] | B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and orr instructions using z registers assembly = f.get_source("asm") @@ -379,7 +401,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] & B[i], name="C") - f = 
tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and and instructions using z registers assembly = f.get_source("asm") @@ -407,7 +431,9 @@ def check_correct_assembly(type): m = te.var("m") A = te.placeholder(m, dtype=type, name="A") C = te.compute((m), lambda i: ~A[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, C])) # Verify we see SVE load instructions and eor instructions using z registers assembly = f.get_source("asm") @@ -440,7 +466,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype="int32", name="B") C = te.compute((m), lambda i: A[B[i]], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see gather instructions in the assembly assembly = f.get_source("asm") @@ -451,65 +479,6 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_codegen_vscale(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - vscale = tvm.tir.vscale() - - @T.prim_func - def main(A: T.Buffer((5,), "int32")): - for i in range(5): - A[i] = 2 * vscale - - build_mod = tvm.tir.build(main, target=target) - llvm = build_mod.get_source() - - assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." 
- - -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_scalable_buffer_load_store(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def my_func(a: T.handle, b: T.handle): - A = T.match_buffer(a, (128,), "float32") - B = T.match_buffer(b, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] - - mod = tvm.tir.build(my_func, target=target) - llvm = mod.get_source("ll") - - assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." - assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." - - -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_scalable_broadcast(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def my_func(a: T.handle): - A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) - - mod = tvm.tir.build(my_func, target=target) - llvm = mod.get_source("ll") - - assert re.findall( - r"shufflevector \( insertelement \(", llvm - ), "No scalable broadcast in generated LLVM." - assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." 
- - @pytest.mark.skipif( llvm_version_major() < 13, reason="Function attribute vscale_range() is not supported in earlier versions of LLVM", @@ -529,7 +498,9 @@ def test_vscale_range_function_attribute(mattr, expect_attr): m = te.var("m") A = te.placeholder(m, dtype="float32", name="A") C = te.compute((m), lambda i: A[i] + 1, name="C") - f = tvm.tir.build(te.create_prim_func([A, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, C])) # Check if the vscale_range() attribute exists ll = f.get_source("ll") @@ -545,49 +516,5 @@ def test_vscale_range_function_attribute(mattr, expect_attr): ), f"Unexpected function attribute vscale_range() was found in generated LLVM IR" -@pytest.mark.skip( - reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", -) -def test_get_active_lane_mask(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def before(a: T.handle): - A = T.match_buffer(a, (30,), "int1") - for i in range(T.ceildiv(30, T.vscale() * 4)): - A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) - - with tvm.target.Target(target): - out = tvm.tir.build(before) - - ll = out.get_source("ll") - assert "get.active.lane.mask" in ll - - -@pytest.mark.skip( - reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", -) -def test_predicated_scalable_buffer(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def before(a: T.handle, b: T.handle): - A = T.match_buffer(a, (16,), "float32") - B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): - for i_1 in T.vectorized(4 * T.vscale()): - if i_0 * 4 * T.vscale() + i_1 < 14: - B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 - - with tvm.target.Target(target): - out = tvm.tir.build(before) - - ll = out.get_source("ll") - assert 
"get.active.lane.mask" in ll - assert "llvm.masked.load" in ll - assert "llvm.masked.store" in ll - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py new file mode 100644 index 000000000000..7ca3083dd5e3 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_llvm_vla.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Codegen tests for VLA extensions +""" + +import re +import pytest + +import tvm +from tvm import te +from tvm.script import tir as T +from tvm.target.codegen import llvm_version_major + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_codegen_vscale(target): + vscale = tvm.tir.vscale() + + @T.prim_func + def main(A: T.Buffer((5,), "int32")): + for i in range(5): + A[i] = 2 * vscale + + with tvm.target.Target(target): + build_mod = tvm.tir.build(main) + + llvm = build_mod.get_source() + assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_scalable_buffer_load_store(target): + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + with tvm.target.Target(target): + mod = tvm.tir.build(my_func) + + llvm = mod.get_source("ll") + assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." 
+ + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_scalable_broadcast(target): + @T.prim_func + def my_func(a: T.handle): + A = T.match_buffer(a, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + with tvm.target.Target(target): + mod = tvm.tir.build(my_func) + + llvm = mod.get_source("ll") + assert re.findall( + r"shufflevector \( insertelement \(", llvm + ), "No scalable broadcast in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." + + +@pytest.mark.skip( + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_get_active_lane_mask(target): + @T.prim_func + def before(a: T.handle): + A = T.match_buffer(a, (30,), "int1") + for i in range(T.ceildiv(30, T.vscale() * 4)): + A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) + + with tvm.target.Target(target): + out = tvm.tir.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + + +@pytest.mark.skip( + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_predicated_scalable_buffer(target): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), 
"float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.tir.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index 22344acfe1d4..f09f7417baf6 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -816,7 +816,7 @@ def before(a: T.handle): warning_msg = ( "Warning: The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " - "AArch64 SVE targets, but the target was " + "VLA targets, but the target was " ) captured = capfd.readouterr().err assert warning_msg in captured diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 9b61255285be..13bb1c60cb53 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -670,7 +670,7 @@ def expected(a: T.handle, b: T.handle): def test_vectorize_with_explicitly_disabled_buffer_level_predication(): - # Since the target has the SVE feature, buffer level predication is enabled + # Since the target has the VLA feature, buffer level predication is enabled # by default. However, it has been explicitly disabled by the pass context # option, so no buffer-level predicates should be added. @T.prim_func