diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 8f3ae9b42460..f7284ec690a4 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -111,7 +111,9 @@ class DataType { return -lanes_as_int; } /*! \return get vscale factor or lanes depending on scalability of the vector. */ - int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); } + int get_lanes_or_vscale_factor() const { + return is_scalable_vector() ? vscale_factor() : lanes(); + } /*! \return whether type is a scalar type. */ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } /*! \return whether type is a scalar type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index ce4a4d6a2845..d06bb779d0bb 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) { template inline PrimExpr make_const(DataType t, ValueType value, Span span) { - if (t.lanes() == 1) { + if (t.is_scalar()) { return MakeConstScalar(t, value, span); } else { - return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + if (t.is_fixed_length_vector()) { + return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + } else { + PrimExpr lanes = + tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor()); + return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span); + } } } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 90dad720393f..2cd2a698debe 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -196,7 +196,8 @@ TVM_REGISTER_NODE_TYPE(StringImmNode); // Cast Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); - ICHECK_EQ(t.lanes(), value.dtype().lanes()); + ICHECK_EQ(t.get_lanes_or_vscale_factor(), 
value.dtype().get_lanes_or_vscale_factor()); + ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); @@ -354,7 +355,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + node->dtype = + DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); node->b = std::move(b); node->span = std::move(span); @@ -376,7 +378,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + node->dtype = + DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), a.dtype().is_scalable_vector()); node->a = std::move(a); node->b = std::move(b); node->span = std::move(span); @@ -412,7 +415,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(true_value.defined()) << "ValueError: true_value is undefined"; ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; ICHECK(condition.dtype().is_bool()); - ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); + ICHECK(condition.dtype().get_lanes_or_vscale_factor() == + true_value.dtype().get_lanes_or_vscale_factor() || + condition.dtype().is_scalar()); ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types. 
" << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 57536422cf64..a9cc4975801a 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -37,19 +37,36 @@ namespace tvm { namespace tir { -// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 -inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { - if (e.dtype().lanes() == lanes) return e; +inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { + if (is_scalable) { + return Mul(Call(DataType::Int(32), builtin::vscale(), {}), lanes_or_vscale_factor); + } else { + return lanes_or_vscale_factor; + } +} + +inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { + // Check if e is already in the expected form + if (e.dtype().get_lanes_or_vscale_factor() == lanes && + e.dtype().is_scalable_vector() == is_scalable) + return e; + if (const BroadcastNode* op = e.as()) { - ICHECK(!e.dtype().is_scalable_vector()); - int broadcast_lanes = static_cast(Downcast(op->lanes)->value); - if (lanes % broadcast_lanes == 0) { - return Broadcast(op->value, lanes); + ICHECK(op->dtype.is_scalable_vector() == is_scalable) + << "Can't broadcast between scalable and fixed length vectors."; + int e_lanes = op->dtype.get_lanes_or_vscale_factor(); + + if (lanes % e_lanes == 0) { + return Broadcast(op->value, CreateNewLanes(is_scalable, lanes)); } } - ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " - << lanes; - return Broadcast(e, lanes); + + ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes=" + << e.dtype().get_lanes_or_vscale_factor() + << " is_scalable=" << e.dtype().is_scalable_vector() << " to " + << lanes; + + return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } // Rewrite vectorized allocation access @@ -62,7 +79,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { 
// class VecAllocAccess : public StmtExprMutator { public: - VecAllocAccess(const VarNode* buf, Var var, int var_lanes) + VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -138,7 +155,7 @@ class VecAllocAccess : public StmtExprMutator { // variable to be replaced Var var_; // the lanes. - int var_lanes_; + PrimExpr var_lanes_; // Analyzer for simplifications arith::Analyzer analyzer_; }; @@ -151,7 +168,7 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), IntImm(var->dtype, 1), var_lanes); } @@ -182,21 +199,30 @@ class Vectorizer : public StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - if (lanes != 1) { + bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); + bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); + if (is_vec_a && is_vec_b) { + // Let's not multiply scalable and fixed length vectors + ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector()) + << "Fixed length and scalable vectors can't be mixed in multiplication."; + } + if (is_vec_a || is_vec_b) { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); - if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - int lanes = static_cast(Downcast(a_ramp->lanes)->value); + if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) { + PrimExpr lanes = a_ramp->lanes; return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); } - if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - int lanes = static_cast(Downcast(b_ramp->lanes)->value); + if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) { + PrimExpr lanes = b_ramp->lanes; return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); } + int a_lanes = 
a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int max_lanes = std::max(a_lanes, b_lanes); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b, max_lanes, is_scalable)); } - return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } return BinaryVec(op); } @@ -227,18 +253,24 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455 - int op_lanes = static_cast(Downcast(op->lanes)->value); - if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { + ICHECK(!base.dtype().is_scalable_vector()) + << "Creating scalable vectors from existing vectors is not supported."; + ICHECK(!stride.dtype().is_scalable_vector()) + << "Ramp stride with scalable dtype is not supported"; + if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) { + ICHECK(op->lanes->IsInstance()) + << "Vectorizing over existing scalable vectors is not supported."; const RampNode* base_ramp = base.as(); + int op_lanes = static_cast(Downcast(op->lanes)->value); int base_ramp_lanes = static_cast(Downcast(base_ramp->lanes)->value); - if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op_lanes))) { + if (analyzer_.CanProve(base_ramp->stride == + stride * make_const(stride.dtype(), base_ramp_lanes))) { return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes); } } int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); - base = BroadcastTo(base, lanes); - stride = BroadcastTo(stride, lanes); + base = BroadcastTo(base, lanes, false); + stride = BroadcastTo(stride, lanes, false); Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( @@ -249,7 +281,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); - if 
(value.dtype().lanes() != 1) { + if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; return GetRef(op); } @@ -267,16 +299,27 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { return GetRef(op); } else { - int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); - return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes); + bool is_scalable = cond.dtype().is_scalable_vector() || t.dtype().is_scalable_vector() || + f.dtype().is_scalable_vector(); + return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t, lanes, is_scalable), + BroadcastTo(f, lanes, is_scalable)); } } + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + if (value.dtype().is_scalable_vector()) { + return Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), value); + } else { + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); + } } } @@ -312,10 +355,17 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { return GetRef(op); } else { - int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); - t = BroadcastTo(t, lanes); - f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + int t_lanes = t.dtype().get_lanes_or_vscale_factor(); + int f_lanes = f.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(t_lanes, f_lanes); + bool is_scalable = t.dtype().is_scalable_vector() || 
f.dtype().is_scalable_vector(); + t = BroadcastTo(t, lanes, is_scalable); + f = BroadcastTo(f, lanes, is_scalable); + if (is_scalable) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {cond, t, f}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}); + } } } // Reinterpret expr @@ -325,8 +375,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorargs[0])) { return GetRef(op); } else { - int lanes = value.dtype().lanes(); - return Call(op->dtype.with_lanes(lanes), op->op, {value}); + int lanes = value.dtype().get_lanes_or_vscale_factor(); + if (value.dtype().is_scalable_vector()) { + return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value}); + } else { + return Call(op->dtype.with_lanes(lanes), op->op, {value}); + } } } // Call @@ -351,7 +405,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.as(); - bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false); + bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false) && + !op->dtype.is_scalable_vector(); if (!vectorizable) { // Cannot vectorize this op @@ -409,7 +464,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorsecond, value)) << "Let cannot bind the same var to two different values"; } - if (value.dtype().lanes() != op->value.dtype().lanes()) { + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; return Let(new_var, value, this->VisitExpr(op->body)); @@ -433,20 +489,28 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (!indices.same_as(op->indices) || !value.same_as(op->value)) { + ICHECK(!op->buffer->dtype.is_scalable_vector()) + << "Vectorizing over scalable buffer elements is not supported in vectorizer."; // How many lanes of indexing are present in the index and - // buffer element type, 
excluding the last index. T + // buffer element type, excluding the last index. int other_index_lanes = op->buffer->dtype.lanes(); for (size_t i = 0; i < indices.size() - 1; i++) { other_index_lanes *= indices[i].dtype().lanes(); + // Only allow the last index to be scalable + ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last index can be scalable."; } // The total number of lanes of indexing, including the last index. - int index_lanes = other_index_lanes * indices[indices.size() - 1].dtype().lanes(); + auto last_index_dtype = indices[indices.size() - 1].dtype(); + int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor(); + int index_lanes = other_index_lanes * lanes_in_last_index; // The total number of lanes in this store operation. Either // the index or the value will be broadcast out to this number // of lanes, depending on which has more lanes. - int total_lanes = std::max(index_lanes, value.dtype().lanes()); + int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); + bool is_last_index_scalable = last_index_dtype.is_scalable_vector(); + int total_lanes = std::max(index_lanes, value_dtype_lanes); ICHECK_EQ(total_lanes % other_index_lanes, 0) << "When storing to buffer " << op->buffer->name << ", cannot produce " << total_lanes @@ -455,11 +519,12 @@ class Vectorizer : public StmtMutator, public ExprFunctorindices = indices; - writer->value = BroadcastTo(value, total_lanes); + writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable); } return std::move(store); @@ -512,7 +577,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; - if (value.dtype().lanes() != op->value.dtype().lanes()) { + if (value.dtype().get_lanes_or_vscale_factor() != + op->value.dtype().get_lanes_or_vscale_factor()) { Var new_var(op->var->name_hint, value.dtype()); let_binding_[op->var] = new_var; return LetStmt(new_var, value, 
this->VisitStmt(op->body)); @@ -566,8 +632,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorname_hint + ".s", var_->dtype); stmt = Substitute(stmt, {{var_, idx}}); - return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), ForKind::kSerial, - stmt); + return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt); } // ProducerStore Stmt VisitStmt_(const ProducerStoreNode* op) final { @@ -582,7 +647,7 @@ class Vectorizer : public StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } template @@ -635,19 +703,22 @@ class Vectorizer : public StmtMutator, public ExprFunctora) && b.same_as(op->b)) { return GetRef(op); } else { - int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); + int a_lanes = a.dtype().get_lanes_or_vscale_factor(); + int b_lanes = b.dtype().get_lanes_or_vscale_factor(); + int lanes = std::max(a_lanes, b_lanes); if (lanes != 1) { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); - if (a.dtype().lanes() == 1 && b_ramp) { + if (a.dtype().is_scalar() && b_ramp) { return Ramp(fcompute(a, b_ramp->base), fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } - if (b.dtype().lanes() == 1 && a_ramp) { + if (b.dtype().is_scalar() && a_ramp) { return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } - return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + bool is_scalable = a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector(); + return fcompute(BroadcastTo(a, 
lanes, is_scalable), BroadcastTo(b, lanes, is_scalable)); } } }; @@ -657,11 +728,7 @@ class LoopVectorizer : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { ICHECK(is_zero(op->min)); - auto* extent_as_int = op->extent.as(); - if (!extent_as_int || extent_as_int->value < 1) { - LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; - } - return Vectorizer(op->loop_var, static_cast(extent_as_int->value))(op->body); + return Vectorizer(op->loop_var, op->extent)(op->body); } else { return StmtMutator::VisitStmt_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 7d0fac242307..dbca006b19cb 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -19,32 +19,29 @@ from tvm import te from tvm.script import ir as I from tvm.script import tir as T +import pytest -def test_vectorize_loop(): - dtype = "int64" - n = te.var("n") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, n) as i: - with ib.for_range(0, 4, kind="vectorize") as j: - A[j] = tvm.tir.const(1, A.dtype) - stmt = ib.get() - - assert isinstance(stmt.body, tvm.tir.For) +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_loop(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + for j in T.vectorized(0, extent): + A[j] = 1 - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) - assert isinstance(stmt, tvm.tir.For) - assert not isinstance(stmt.body, tvm.tir.For) - assert len(stmt.body.indices) == 1 - assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) 
- assert isinstance(stmt.body.value, tvm.tir.Broadcast) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_vector(): - dtype = "int64" n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32x4", name="A") @@ -64,28 +61,90 @@ def test_vectorize_vector(): assert isinstance(stmt.body.value, tvm.tir.Broadcast) -def test_vectorize_with_if(): - n = te.var("n") - x = te.var("x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - with ib.if_scope(x < n): - A[i] = A[i] + 1 - with ib.else_scope(): - with ib.if_scope(i < n): - A[i] = 2.0 - stmt = ib.get() +def test_vectorize_vector_scalable_error(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(T.vscale() * 4): + A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) + + error_msg = f"Creating scalable vectors from existing vectors is not supported." + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) + + +def test_vectorize_vector_scalable_error2(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32xvscalex4")): + for j in T.vectorized(4): + A[j] = T.Broadcast(T.float32(1), T.vscale() * 4) + + error_msg = f"Vectorizing over scalable buffer elements is not supported in vectorizer." 
+ with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body - assert isinstance(stmt, tvm.tir.IfThenElse) - assert len(stmt.then_case.indices) == 1 - assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp) - assert isinstance(stmt.then_case.value, tvm.tir.Add) - assert stmt.then_case.value.dtype == "float32x4" - assert isinstance(stmt.else_case, tvm.tir.For) +def test_vectorize_vector_scalable_error3(): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(4): + A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( + T.float32(1), T.vscale() * 4 + ) + + error_msg = f"Vectorizing over existing scalable vectors is not supported." + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) + + +def test_vectorize_vector_scalable_error4(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def main(A: T.Buffer((25,), "float32")): + for j in T.vectorized(T.vscale() * 4): + A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( + T.float32(1), T.vscale() * 4 + ) + + error_msg = f"Creating scalable vectors from existing vectors is not supported." 
+ with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_with_if(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + for i in T.vectorized(extent): + if x < n: + A[i] = A[i] + T.float32(1) + else: + if i < n: + A[i] = T.float32(2) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + if x < n: + A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( + T.float32(1), extent + ) + else: + for i_s in range(extent): + if i_s < n: + A[i_s] = T.float32(2) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_with_if_cond_int64(): @@ -98,25 +157,33 @@ def test_vectorize_with_if_cond_int64(): f = tvm.build(s, [A, B], "llvm") -def test_vectorize_let(): - v = tvm.tir.Var("v", "float32") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body)) - A[i] = v + 2 +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_let(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i in T.vectorized(extent): + v = A[i] + T.float32(1) + A[i] = v + T.float32(2) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get())) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body - assert isinstance(stmt, tvm.tir.LetStmt) - assert stmt.value.dtype == "float32x4" + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) + A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + 
tvm.ir.assert_structural_equal(mod, After) -def test_vectorize_with_le_cond(): +@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) +def test_vectorize_with_le_cond(extent): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: + with ib.for_range(0, extent, kind="vectorize") as i: with ib.if_scope(i <= n): A[i] = A[i] + 1 stmt = ib.get() @@ -124,14 +191,16 @@ def test_vectorize_with_le_cond(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + # Check that the loop wasn't vectorised assert isinstance(stmt, tvm.tir.For) -def test_vectorize_with_ge_cond(): +@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) +def test_vectorize_with_ge_cond(extent): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: + with ib.for_range(0, extent, kind="vectorize") as i: with ib.if_scope(i >= n): A[i] = A[i] + 1 stmt = ib.get() @@ -139,39 +208,51 @@ def test_vectorize_with_ge_cond(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + # Check that the loop wasn't vectorised assert isinstance(stmt, tvm.tir.For) -def test_vectorize_if_then_else(): - n = te.var("n") - x = te.var("x") - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, kind="vectorize") as i: - A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) - stmt = ib.get() +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_if_then_else_scalarize(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i in T.vectorized(extent): + A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i]) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) - stmt = 
tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32")): + for i_s in range(extent): + A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) - assert isinstance(stmt, tvm.tir.For) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) - ib = tvm.tir.ir_builder.create() - A = ib.pointer("float32", name="A") - with ib.for_range(0, n) as k: - with ib.for_range(0, 4, kind="vectorize") as i: - A[k * 4 + i] = tvm.tir.call_intrin( - "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0 - ) - stmt = ib.get() - assert isinstance(stmt.body, tvm.tir.For) +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_if_then_else_vector(extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32): + for i in range(n): + for j in T.vectorized(extent): + A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), n: T.int32): + for i in range(n): + A[T.Ramp(i * extent, 1, extent)] = T.if_then_else( + i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) + ) - assert not isinstance(stmt.body, tvm.tir.For) - assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_while_fail(): @@ -229,23 +310,141 @@ def test_vectorize_dtype_mismatch(): tvm.lower(s, [A], "llvm", simple_mode=True) -def test_vectorize_with_reinterpret(): +@pytest.mark.parametrize( + "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")] +) +def test_vectorize_with_reinterpret(extent, vec_str): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), 
"int32"), B: T.Buffer((16,), "float32")): - for i in T.vectorized(0, 16): + for i in T.vectorized(0, extent): B[i] = T.reinterpret("float32", A[i]) @I.ir_module class After: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): - B[0:16] = T.reinterpret("float32x16", A[0:16]) + B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize( + "op", + ( + T.Mul, + T.Add, + T.Sub, + T.Div, + T.Mod, + T.FloorDiv, + T.FloorMod, + T.Min, + T.Max, + T.EQ, + T.LT, + T.LE, + T.GE, + T.GT, + T.NE, + ), +) +def test_vectorize_binary(op, extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = op(T.float32(3), B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("op", (T.And, T.Or)) +def test_vectorize_logical(op, extent): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + for j in T.vectorized(extent): + A[j] = op(T.bool(1), B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +def test_vectorize_select(extent): + @I.ir_module + class 
Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.Select(T.bool(True), A[j], B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Select( + T.Broadcast(T.bool(True), extent), + A[T.Ramp(0, 1, extent)], + B[T.Ramp(0, 1, extent)], + ) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")]) +def test_vectorize_cast(extent, vec_str): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.Cast("int32", B[j]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +def test_illegal_extent(): + @I.ir_module(check_well_formed=False) + class Mod: + @T.prim_func + def main(A: T.Buffer((25,), "int32")): + n = T.Var("n", dtype="int32") + for j in T.vectorized(n): + A[j] = 3 + + error_msg = f"Invalid expression for scalable lanes n" + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Mod) + + if __name__ == "__main__": tvm.testing.main()