diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 40707675fe75..170cfe0171cd 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -213,78 +213,6 @@ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_sinfo_args = Optional>(), Optional opt_span = Optional()); -/*! - * \brief Condition expression - * - * Unlike traditional statement `if`s, the if evalutes - * to the result of the branch taken. - * - * x = if (true) { 1 } else { 0 }; // x is 1 - * y = if (false) { 1 } else { 0 }; // y is 0 - * - * \note This is similar to C's ternary operator. - */ -class IfNode : public ExprNode { - public: - /*! \brief The condition. */ - Expr cond; - /*! \brief The expression evaluated when condition is true. */ - Expr true_branch; - /*! \brief The expression evaluated when condition is false */ - Expr false_branch; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cond", &cond); - v->Visit("true_branch", &true_branch); - v->Visit("false_branch", &false_branch); - v->Visit("_checked_type_", &checked_type_); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); - } - - bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { - equal->MarkGraphNode(); - return equal(cond, other->cond) && equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); - hash_reduce(cond); - hash_reduce(true_branch); - hash_reduce(false_branch); - hash_reduce(struct_info_); - } - - static constexpr const char* _type_key = "relax.expr.If"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); -}; - -class If : public Expr { - public: - /*! - * \brief The constructor - * \param cond The condition of a if node. - * \param true_branch The fall through branch - * \param false_branch The branch for execution when condition is false. - * \param span The source span of the expression. - */ - TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); -}; - -/*! - * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. - * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -If WithFields(If if_expr, Optional opt_cond = Optional(), - Optional opt_true_branch = Optional(), - Optional opt_false_branch = Optional(), - Optional opt_span = Optional()); - /*! \brief Tuple container */ class TupleNode : public ExprNode { public: @@ -915,18 +843,113 @@ class SeqExprNode : public ExprNode { class SeqExpr : public Expr { public: + /* \brief Implicit conversion constructor + * + * Relax nodes that introduce a new scope (e.g. `relax::Function`) + * are required to be held as SeqExpr. This implicit conversion + * provides allows callsites to use these member variables when the + * C++ compile-time type is a `relax::Expr`. For example, + * a transform may use `func.CopyOnWrite()->body = expr;`. + * + * If the expression is already a `relax::SeqExpr`, the same + * underlying `relax::SeqExprNode` is used, and no copies are made. + */ + TVM_DLL SeqExpr(Expr body); // NOLINT(*) + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); }; +/*! + * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * x = if (true) { 1 } else { 0 }; // x is 1 + * y = if (false) { 1 } else { 0 }; // y is 0 + * + * \note This is similar to C's ternary operator. + */ +class IfNode : public ExprNode { + public: + /*! \brief The condition. */ + Expr cond; + /*! \brief The expression evaluated when condition is true. */ + SeqExpr true_branch; + /*! \brief The expression evaluated when condition is false */ + SeqExpr false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(cond); + hash_reduce(true_branch); + hash_reduce(false_branch); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.If"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); +}; + +class If : public Expr { + public: + /*! + * \brief The constructor + * + * \param cond The condition of a if node. + * + * \param true_branch The fall through branch. If this is not a + * SeqExpr, it will be wrapped in a SeqExpr, to satisfy the + * Relax IR requirement that all scopes be contained in a + * SeqExpr. + * + * \param false_branch The branch for execution when condition is + * false. If this is not a SeqExpr, it will be wrapped in a + * SeqExpr, to satisfy the Relax IR requirement that all scopes + * be contained in a SeqExpr. + * + * \param span The source span of the expression. + */ + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); +}; + +/*! + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +If WithFields(If if_expr, Optional opt_cond = Optional(), + Optional opt_true_branch = Optional(), + Optional opt_false_branch = Optional(), + Optional opt_span = Optional()); + /*! \brief A Relax function. */ class FunctionNode : public BaseFuncNode { public: /*! \brief The parameters to the function. */ Array params; /*! \brief The body of the function. */ - Expr body; + SeqExpr body; /*! \brief The return type of the function. */ StructInfo ret_struct_info; /*! \brief Whether the function is annotated as pure or not. */ @@ -968,6 +991,27 @@ class FunctionNode : public BaseFuncNode { class Function : public BaseFunc { public: + /*! + * \brief Construct a Relax Function + * + * \param params The parameters accepted by the function + * + * \param body The body of the function. If this is not a + * SeqExpr, it will be wrapped in a SeqExpr, to satisfy the + * Relax IR requirement that all scopes be contained in a + * SeqExpr. + * + * \param ret_struct_info The StructInfo returned by the function. + * If NullOpt, will be inferred from the StructInfo of the + * function's body. + * + * \param is_pure The purity of the function. + * + * \param attrs Any attributes associated with the function. + * Defaults to an empty dictionary. + * + * \param span The source span of the expression. + */ TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 02b5a2ee671a..d35a462579d9 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -166,12 +166,9 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { } } VisitExpr(func); - if (const auto* b_node = func->body.as()) { - ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body " << b_node->body; - output_names = expr_tensor_map_[b_node->body]; - } else { - LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; - } + ICHECK(expr_tensor_map_.count(func->body->body)) + << "Can not find seqexpr body " << func->body->body; + output_names = expr_tensor_map_[func->body->body]; // remove const nodes as weights Array valid_nodes; std::set ignore_inputs; diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 0ece7a51cac8..76775a5ba322 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -1268,13 +1268,9 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(call->args[i], var_layout_map_[func->params[i]]); } } - if (const auto* b_node = func->body.as()) { - if (b_node->body->IsInstance() && - var_layout_map_.count(Downcast(b_node->body))) { - SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); - } - } else { - LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; + if (func->body->body->IsInstance() && + var_layout_map_.count(Downcast(func->body->body))) { + SetExprLayout(ret, var_layout_map_[Downcast(func->body->body)]); } } @@ -1288,13 +1284,9 @@ class LayoutInfer : public ExprVisitor { if (producer->IsInstance() && local_funcs_.count(Downcast(producer)->op)) { const auto& caller = local_funcs_[Downcast(producer)->op]; - if (const auto* b_node = caller->body.as()) { - if (b_node->body->IsInstance() && - var_map_.count(Downcast(b_node->body))) { - SetExprLayout(b_node->body, param_layout); - } - } else { - LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body; + if (caller->body->body->IsInstance() && + var_map_.count(Downcast(caller->body->body))) { + SetExprLayout(caller->body->body, param_layout); } } } diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index b4a0fc4b9883..a73e6fb233bf 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -281,11 +281,7 @@ class WellFormedChecker : public relax::ExprVisitor, } } - if (auto seq = op->body.as()) { - this->VisitSeqExpr(seq); - } else { - Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions"); - } + this->VisitSeqExpr(op->body.get()); is_dataflow_ = old_dataflow_state; dataflow_var_set_ = prev_dataflow_var_set; @@ -367,21 +363,17 @@ class WellFormedChecker : public relax::ExprVisitor, } else { Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); } - auto true_seq = op->true_branch.as(); - auto false_seq = op->false_branch.as(); - if (true_seq && false_seq) { - std::unordered_set previous_var_set = var_set_; - std::unordered_set previous_symbolic_var_set = - symbolic_var_set_; - this->VisitSeqExpr(true_seq); - var_set_ = previous_var_set; - symbolic_var_set_ = previous_symbolic_var_set; - this->VisitSeqExpr(false_seq); - var_set_ = previous_var_set; - symbolic_var_set_ = previous_symbolic_var_set; - } else { - Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs"); - } + + std::unordered_set previous_var_set = var_set_; + std::unordered_set previous_symbolic_var_set = + symbolic_var_set_; + this->VisitSeqExpr(op->true_branch.get()); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + this->VisitSeqExpr(op->false_branch.get()); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + CheckStructInfo(op); } diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index 20b2a6fce698..b260ea24bed3 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -36,7 +36,7 @@ Map ExtractArgIdx(String pattern_name, Function f) { ICHECK(pattern) << "Unsupported op_type " << pattern_name; auto bindings = AnalyzeVar2Value(f); - auto inner_body = Downcast(f->body)->body; + auto inner_body = f->body->body; auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); ICHECK(matched_expr) << "ValueError: " << "For named pattern \"" << pattern_name diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index cf8934c372e2..c0b8d1e1df08 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -59,13 +59,30 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -static Expr TryGetValOfVar(const Expr& expr, const Map& var2val) { - if (var2val.empty()) return expr; +static Expr TryGetValOfVar(Expr expr, const Map& var2val) { + auto unwrap = [&](Expr expr) -> Optional { + // Unwrap variables into the value to which they are bound. + if (var2val.size()) { + if (const VarNode* var = expr.as()) { + if (auto may = var2val.Get(GetRef(var))) { + return may.value(); + } + } + } + + // Unwrap SeqExpr with no bindings. These can occur due to Relax + // IR constraints for the bodies of Function and If nodes. + if (auto seq = expr.as()) { + if (seq->blocks.empty()) { + return seq->body; + } + } + + return NullOpt; + }; - // if not match, try to match value of var if expr is a var. - if (const VarNode* var = expr.as()) { - auto may = var2val.Get(GetRef(var)); - if (may.defined()) return may.value(); + while (auto unwrapped = unwrap(expr)) { + expr = unwrapped.value(); } return expr; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1b5551e5097b..16bf61f15c78 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -476,6 +476,14 @@ TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bind TVM_REGISTER_NODE_TYPE(SeqExprNode); +SeqExpr::SeqExpr(Expr body) { + if (auto seq = body.as()) { + *this = seq.value(); + } else { + *this = SeqExpr(Array{}, body); + } +} + SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { ObjectPtr n = make_object(); n->blocks = std::move(blocks); diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 19faaad58b87..a7348483f680 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -65,13 +65,10 @@ class AppendLossMutator : private ExprMutator { num_backbone_outputs_(num_backbone_outputs) {} Expr VisitExpr_(const FunctionNode* func) final { - CHECK(func->body->IsInstance() && loss_function_->body->IsInstance()) - << "The bodies of the backbone and the loss function must be SeqExpr."; - // Well-formed checks and setting up class members - loss_body_ = Downcast(loss_function_->body); + loss_body_ = loss_function_->body; CheckLossBody(); - BackboneReturnToArr(func->body.as()->body); + BackboneReturnToArr(func->body->body); CheckAndRemapBackboneReturn(); CheckAndRemapLossParams(loss_function_->params); diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index a2a3e96dd567..66a8dafb98c9 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -1257,9 +1257,19 @@ class CompositeFunctionAnnotator : public ExprMutator { params.push_back(new_v); } + // We cannot delegate to `ExprMutator::VisitExpr_(const FunctionNode*)` at this point, as it + // would recursively visit the Call node. However, we are still required to generate + // well-formed Relax IR. As a result, we need to build the SeqExpr ourselves. + Var local_func_var("local_func", GetStructInfo(f_inner)); + Var output_var("output", f_inner->ret_struct_info); + SeqExpr new_body({BindingBlock({ + VarBinding(local_func_var, f_inner), + VarBinding(output_var, Call(local_func_var, params)), + })}, + output_var); + // pure if the inner func is pure (no need to force purity if it's forced for the inner func) - return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info, - f_inner->is_pure); + return Function(param_vars, new_body, func_node->ret_struct_info, f_inner->is_pure); } private: diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 11785ab73ac6..f6828470275f 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -438,9 +438,7 @@ class FusedTIRConstructor : public ExprVisitor { ExprVisitor::VisitExpr_(func); // Step 3. Create and remap buffers for function output - ICHECK(func->body->IsInstance()) - << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey(); - Expr body = Downcast(func->body)->body; + Expr body = func->body->body; auto it = func_info_.expr2buffers.find(body); ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 70e3e37876fd..cd07af37e0f0 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -664,8 +664,6 @@ class GradientMutator : private ExprMutator { } Expr VisitExpr_(const FunctionNode* func) final { - CHECK(func->body->IsInstance()) << "The body of the function must be SeqExpr."; - orig_params_ = func->params; Expr new_body = this->VisitExpr(func->body); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 44a2cd338c5e..c8b616b4bcb5 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -27,8 +27,8 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); std::vector> branches{ - PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), - PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), + PrintSeqExpr(n->true_branch, n_p->Attr("true_branch"), d, false), + PrintSeqExpr(n->false_branch, n_p->Attr("false_branch"), d, false), }; if (var.defined()) { for (Array& stmts : branches) { diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 458eb3766de8..3b5302bebc3e 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -119,8 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 6. Print body - Array body = - PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); + Array body = PrintSeqExpr(n->body, n_p->Attr("body"), d, /*use_ret=*/true); (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); return HeaderWrapper(d, FunctionDoc(func_name, params, {decorator}, ret_type, (*f)->stmts)); }); diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index 0daf9d4a1f7a..f3d2432549e1 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -439,7 +439,7 @@ def test_if(): if_node = relax.If(x, x, x) basic_check( if_node, - "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["If", "\tVar", "\tSeqExpr", "\t\tVar", "\tSeqExpr", "\t\tVar"]), "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]), )