Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 111 additions & 48 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,46 @@ class SeqStmtNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
* If value do not have side-effect, this node can be safely removed.
*/
class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
PrimExpr value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());

explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*! \brief Sequence statement. */
class SeqStmt : public Stmt {
public:
Expand Down Expand Up @@ -718,6 +758,10 @@ class SeqStmt : public Stmt {
* \note This function can directly return an element
* if it is the only element in the sequence.
*
* \note If the only argument to this function is a SeqStmt, and if
* no flattening of the SeqStmt is required, then the SeqStmt
* will be returned as-is.
*
* \param seq_args The list of arguments to be flattened.
* \tparam Args arguments
* \return The constructed statement
Expand All @@ -726,34 +770,93 @@ class SeqStmt : public Stmt {
static Stmt Flatten(Args&&... seq_args) {
Array<Stmt> seq;
runtime::detail::for_each(Flattener(&seq), std::forward<Args>(seq_args)...);
if (seq.size() == 1) return seq[0];

if (seq.empty()) {
return Evaluate(0);
} else if (seq.size() == 1) {
return seq[0];
}

// If the argument is a single SeqStmt argument with no
// flattening or unwrapping required required, then we may
// return the SeqStmt as-is.
if constexpr (sizeof...(seq_args) == 1) {
if (auto opt = Flattener::AsSeqStmt(std::forward<Args>(seq_args)...)) {
SeqStmt original = opt.value();
bool all_same = [&]() {
if (original->seq.size() != seq.size()) {
return false;
}
for (size_t i = 0; i < seq.size(); i++) {
if (!original->seq[i].same_as(seq[i])) {
return false;
}
}
return true;
}();
if (all_same) {
return original;
}
}
}

return SeqStmt(seq);
}
/*! \brief Helper class to flatten sequence of arguments into Array. */
class Flattener {
public:
explicit Flattener(Array<Stmt>* seq) : seq_(seq) {}

template <typename T>
static Optional<SeqStmt> AsSeqStmt(const T& t) {
if constexpr (std::is_same_v<T, SeqStmt>) {
return t;
} else if constexpr (!std::is_base_of_v<T, SeqStmt>) {
return NullOpt;
} else if (auto* ptr = t.template as<SeqStmtNode>()) {
return GetRef<SeqStmt>(ptr);
} else {
return NullOpt;
}
}

template <typename T>
void operator()(size_t i, const T& stmt_or_seq) const {
if constexpr (std::is_base_of_v<ObjectRef, T>) {
// Early bail-out, applicable to any ObjectRef
if (!stmt_or_seq.defined()) return;
if (!stmt_or_seq.defined()) {
return;
}
}

if constexpr (std::is_same_v<T, SeqStmt>) {
// No need for dynamic type-checking if the static type is a
// SeqStmt.
// Static type-checking for a SeqStmt that could be flattened.
(*this)(0, stmt_or_seq->seq);
} else if constexpr (std::is_base_of_v<T, SeqStmt>) {
return;
}

if constexpr (std::is_base_of_v<T, SeqStmt>) {
// Dynamic type-checking for a SeqStmt that could be
// flattened.
if (auto* op = stmt_or_seq.template as<SeqStmtNode>()) {
operator()(0, op->seq);
} else {
seq_->push_back(stmt_or_seq);
return;
}
} else if constexpr (std::is_base_of_v<Stmt, T>) {
}

if constexpr (std::is_base_of_v<T, Evaluate>) {
// Evaluate(0) is used to represent a no-op, and may be
// generated by previous calls to SeqStmt::Flatten(). These
// should be removed to ensure that Flatten(a+b) is equivalent
// to Flatten(Flatten(a), Flatten(b)).
if (auto* op = stmt_or_seq.template as<EvaluateNode>()) {
if (auto* as_int = op->value.template as<IntImmNode>(); as_int && as_int->value == 0) {
return;
}
}
}

if constexpr (std::is_base_of_v<Stmt, T>) {
// Any other Stmt type just gets appended.
seq_->push_back(stmt_or_seq);
} else {
Expand Down Expand Up @@ -819,46 +922,6 @@ class IfThenElse : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
* If value do not have side-effect, this node can be safely removed.
*/
class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
PrimExpr value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span());

explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
};

/*!
* \brief The kind of the loop.
*
Expand Down
6 changes: 3 additions & 3 deletions src/relay/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class AOTMainLowerer : public MixedModeVisitor {
* runner function needs to be legalized by the LegalizePackedCalls pass.
*/
tir::PrimFunc CreateMainFunc(String mod_name) {
tir::Stmt body = tir::SeqStmt(stmts_);
tir::Stmt body = tir::SeqStmt::Flatten(stmts_);
// Allocate the sids
std::unordered_map<int, bool> allocated;
std::vector<std::pair<int64_t, int64_t>> sids_to_allocate;
Expand Down Expand Up @@ -674,7 +674,7 @@ class AOTMainLowerer : public MixedModeVisitor {
}));
}

tir::Stmt body = tir::SeqStmt({func_call});
tir::Stmt body = tir::SeqStmt::Flatten(func_call);
stmts_.push_back(body);
}

Expand Down Expand Up @@ -717,7 +717,7 @@ class AOTMainLowerer : public MixedModeVisitor {
{tvm::tir::StringImm(device_hook_name), context})));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
return tir::SeqStmt::Flatten(device_hooks);
}

/*!
Expand Down
6 changes: 3 additions & 3 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}));
}

tir::Stmt body = tir::SeqStmt({func_call});
tir::Stmt body = tir::SeqStmt::Flatten(func_call);
stmts_.push_back(body);
}

Expand Down Expand Up @@ -570,7 +570,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
{tvm::tir::StringImm(device_hook_name), context})));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
return tir::SeqStmt::Flatten(device_hooks);
}

/**
Expand Down Expand Up @@ -736,7 +736,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// the packed function calls don't pack their arguments. The AOT
// runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);
tir::Stmt body = tir::SeqStmt::Flatten(stmts_);
// Allocate the sids
std::unordered_map<int, bool> allocated;

Expand Down
9 changes: 1 addition & 8 deletions src/script/ir_builder/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,7 @@ inline void AddToParent(tvm::tir::Stmt stmt) {
* \return The SeqStmt.
*/
inline tvm::tir::Stmt AsStmt(const Array<tvm::tir::Stmt>& stmt) {
using namespace tvm::tir;
if (stmt.empty()) {
return tvm::tir::Evaluate(0);
} else if (stmt.size() == 1) {
return stmt[0];
} else {
return SeqStmt(stmt);
}
return tvm::tir::SeqStmt::Flatten(stmt);
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class HoistAllocatesMutator : public StmtExprMutator {
for (auto it = allocates_.rbegin(); it != allocates_.rend(); it++) {
Allocate current_alloc = *it;
if (it != allocates_.rbegin()) {
new_main_func_body = SeqStmt({new_main_func_body});
new_main_func_body = SeqStmt::Flatten(new_main_func_body);
}
new_main_func_body =
Allocate(current_alloc->buffer_var, current_alloc->dtype, current_alloc->extents,
Expand Down
19 changes: 19 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,25 @@ TVM_REGISTER_NODE_TYPE(PrefetchNode);

// SeqStmt
SeqStmt::SeqStmt(Array<Stmt> seq, Span span) {
bool requires_flattening = std::any_of(
seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance<SeqStmtNode>(); });

if (requires_flattening) {
auto flattened = SeqStmt::Flatten(seq);
if (auto* ptr = flattened.as<SeqStmtNode>()) {
seq = ptr->seq;
} else {
seq = {flattened};
}
}

ICHECK_NE(seq.size(), 0) << "An empty SeqStmt is prohibited. "
<< "To write a no-op, use Evaluate(0), "
<< "or the result of SeqStmt::Flatten()";
ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. "
<< "Use the node " << seq[0] << "directly, "
<< "or for dynamic usage, normalize using SeqStmt::Flatten()";

auto node = make_object<SeqStmtNode>();
node->seq = std::move(seq);
node->span = std::move(span);
Expand Down
8 changes: 4 additions & 4 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,11 @@ Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
Array<Stmt> seq = Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
return GetRef<Stmt>(op);
return SeqStmt::Flatten(GetRef<Stmt>(op));
} else {
auto n = CopyOnWrite(op);
n->seq = std::move(seq);
return Stmt(n);
auto node = CopyOnWrite(op);
node->seq = std::move(seq);
return SeqStmt::Flatten(SeqStmt(node));
}
}

Expand Down
7 changes: 3 additions & 4 deletions src/tir/schedule/analysis/reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,16 +394,15 @@ void ExtractReductionUpdates(const Optional<ScheduleState>& self, Block block,
if (p_seq == nullptr && p_buf_store == nullptr) {
ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5);
}
SeqStmt seq =
p_seq != nullptr ? GetRef<SeqStmt>(p_seq) : SeqStmt({GetRef<BufferStore>(p_buf_store)});
if (static_cast<int>(seq->seq.size()) != n_buffers) {
Array<Stmt> seq = p_seq != nullptr ? p_seq->seq : Array<Stmt>{GetRef<BufferStore>(p_buf_store)};
if (static_cast<int>(seq.size()) != n_buffers) {
ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6);
}

// Step 2.
// - Create BufferStores according to the variables being stored.
// - Construct the mapping from reduction buffers to the index.
for (const Stmt& stmt : seq->seq) {
for (const Stmt& stmt : seq) {
const auto* buf_store = stmt.as<BufferStoreNode>();
if (buf_store == nullptr) {
ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5);
Expand Down
25 changes: 0 additions & 25 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,31 +178,6 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
}
}

Stmt VisitStmt_(const SeqStmtNode* op) final {
auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(op, true));

bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(),
[](const auto& stmt) { return is_no_op(stmt); });

if (need_compact) {
Array<Stmt> filtered;
for (Stmt stmt : ret->seq) {
if (!is_no_op(stmt)) {
filtered.push_back(std::move(stmt));
}
}
ret = SeqStmt(filtered);
}

if (ret->size() == 0) {
return Evaluate(0);
} else if (ret->size() == 1) {
return ret->seq[0];
} else {
return std::move(ret);
}
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = GetRef<BufferStore>(op);

Expand Down
3 changes: 1 addition & 2 deletions tests/cpp/ir_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ TEST(IRF, StmtMutator) {
auto* ref2 = body2.get();
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
body = SeqStmt({body, body2});
body = v(std::move(body));
Expand All @@ -296,7 +295,7 @@ TEST(IRF, StmtMutator) {
Stmt body2 = Evaluate(1);
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
auto bref = body;
body = SeqStmt({body, body2});
body = v(std::move(body));
Expand Down
Loading