Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 117 additions & 73 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,78 +213,6 @@ Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<StructInfo>> opt_sinfo_args = Optional<Array<StructInfo>>(),
Optional<Span> opt_span = Optional<Span>());

/*!
* \brief Condition expression
*
* Unlike traditional statement `if`s, the if evalutes
* to the result of the branch taken.
*
* x = if (true) { 1 } else { 0 }; // x is 1
* y = if (false) { 1 } else { 0 }; // y is 0
*
* \note This is similar to C's ternary operator.
*/
class IfNode : public ExprNode {
public:
/*! \brief The condition. */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
Expr true_branch;
/*! \brief The expression evaluated when condition is false */
Expr false_branch;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
v->Visit("span", &span);
}

bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(cond);
hash_reduce(true_branch);
hash_reduce(false_branch);
hash_reduce(struct_info_);
}

static constexpr const char* _type_key = "relax.expr.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};

class If : public Expr {
public:
/*!
* \brief The constructor
* \param cond The condition of a if node.
* \param true_branch The fall through branch
* \param false_branch The branch for execution when condition is false.
* \param span The source span of the expression.
*/
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
};

/*!
* \brief Returns \p if_expr with the given properties. A null property denotes 'no change'.
* Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new
* fields.
*/
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_true_branch = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
Expand Down Expand Up @@ -915,18 +843,113 @@ class SeqExprNode : public ExprNode {

class SeqExpr : public Expr {
public:
/* \brief Implicit conversion constructor
*
* Relax nodes that introduce a new scope (e.g. `relax::Function`)
* are required to be held as SeqExpr. This implicit conversion
* provides allows callsites to use these member variables when the
* C++ compile-time type is a `relax::Expr`. For example,
* a transform may use `func.CopyOnWrite()->body = expr;`.
*
* If the expression is already a `relax::SeqExpr`, the same
* underlying `relax::SeqExprNode` is used, and no copies are made.
*/
TVM_DLL SeqExpr(Expr body); // NOLINT(*)

TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode);
};

/*!
* \brief Condition expression
*
* Unlike traditional statement `if`s, the if evalutes
* to the result of the branch taken.
*
* x = if (true) { 1 } else { 0 }; // x is 1
* y = if (false) { 1 } else { 0 }; // y is 0
*
* \note This is similar to C's ternary operator.
*/
class IfNode : public ExprNode {
Comment thread
quic-sanirudh marked this conversation as resolved.
public:
/*! \brief The condition. */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
SeqExpr true_branch;
/*! \brief The expression evaluated when condition is false */
SeqExpr false_branch;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("_checked_type_", &checked_type_);
v->Visit("struct_info_", &struct_info_);
v->Visit("span", &span);
}

bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(cond, other->cond) && equal(true_branch, other->true_branch) &&
equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(cond);
hash_reduce(true_branch);
hash_reduce(false_branch);
hash_reduce(struct_info_);
}

static constexpr const char* _type_key = "relax.expr.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};

class If : public Expr {
public:
/*!
* \brief The constructor
*
* \param cond The condition of a if node.
*
* \param true_branch The fall through branch. If this is not a
* SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
* Relax IR requirement that all scopes be contained in a
* SeqExpr.
*
* \param false_branch The branch for execution when condition is
* false. If this is not a SeqExpr, it will be wrapped in a
* SeqExpr, to satisfy the Relax IR requirement that all scopes
* be contained in a SeqExpr.
*
* \param span The source span of the expression.
*/
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode);
};

/*!
* \brief Returns \p if_expr with the given properties. A null property denotes 'no change'.
* Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new
* fields.
*/
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_true_branch = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
Optional<Span> opt_span = Optional<Span>());

/*! \brief A Relax function. */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief The parameters to the function. */
Array<Var> params;
/*! \brief The body of the function. */
Expr body;
SeqExpr body;
/*! \brief The return type of the function. */
StructInfo ret_struct_info;
/*! \brief Whether the function is annotated as pure or not. */
Expand Down Expand Up @@ -968,6 +991,27 @@ class FunctionNode : public BaseFuncNode {

class Function : public BaseFunc {
public:
/*!
* \brief Construct a Relax Function
*
* \param params The parameters accepted by the function
*
* \param body The body of the function. If this is not a
* SeqExpr, it will be wrapped in a SeqExpr, to satisfy the
* Relax IR requirement that all scopes be contained in a
* SeqExpr.
*
* \param ret_struct_info The StructInfo returned by the function.
* If NullOpt, will be inferred from the StructInfo of the
* function's body.
*
* \param is_pure The purity of the function.
*
* \param attrs Any attributes associated with the function.
* Defaults to an empty dictionary.
*
* \param span The source span of the expression.
*/
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());

Expand Down
9 changes: 3 additions & 6 deletions src/contrib/msc/core/ir/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,9 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) {
}
}
VisitExpr(func);
if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
ICHECK(expr_tensor_map_.count(b_node->body)) << "Can not find seqexpr body " << b_node->body;
output_names = expr_tensor_map_[b_node->body];
} else {
LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
}
ICHECK(expr_tensor_map_.count(func->body->body))
<< "Can not find seqexpr body " << func->body->body;
output_names = expr_tensor_map_[func->body->body];
// remove const nodes as weights
Array<MSCJoint> valid_nodes;
std::set<String> ignore_inputs;
Expand Down
20 changes: 6 additions & 14 deletions src/contrib/msc/core/transform/set_expr_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,13 +1268,9 @@ class LayoutInfer : public ExprVisitor {
SetExprLayout(call->args[i], var_layout_map_[func->params[i]]);
}
}
if (const auto* b_node = func->body.as<relax::SeqExprNode>()) {
if (b_node->body->IsInstance<VarNode>() &&
var_layout_map_.count(Downcast<Var>(b_node->body))) {
SetExprLayout(ret, var_layout_map_[Downcast<Var>(b_node->body)]);
}
} else {
LOG(FATAL) << "Function body should be SeqExpr, get " << func->body;
if (func->body->body->IsInstance<VarNode>() &&
var_layout_map_.count(Downcast<Var>(func->body->body))) {
SetExprLayout(ret, var_layout_map_[Downcast<Var>(func->body->body)]);
}
}

Expand All @@ -1288,13 +1284,9 @@ class LayoutInfer : public ExprVisitor {
if (producer->IsInstance<CallNode>() &&
local_funcs_.count(Downcast<Call>(producer)->op)) {
const auto& caller = local_funcs_[Downcast<Call>(producer)->op];
if (const auto* b_node = caller->body.as<relax::SeqExprNode>()) {
if (b_node->body->IsInstance<VarNode>() &&
var_map_.count(Downcast<Var>(b_node->body))) {
SetExprLayout(b_node->body, param_layout);
}
} else {
LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body;
if (caller->body->body->IsInstance<VarNode>() &&
var_map_.count(Downcast<Var>(caller->body->body))) {
SetExprLayout(caller->body->body, param_layout);
}
}
}
Expand Down
32 changes: 12 additions & 20 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,7 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

if (auto seq = op->body.as<SeqExprNode>()) {
this->VisitSeqExpr(seq);
} else {
Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions");
}
this->VisitSeqExpr(op->body.get());

is_dataflow_ = old_dataflow_state;
dataflow_var_set_ = prev_dataflow_var_set;
Expand Down Expand Up @@ -367,21 +363,17 @@ class WellFormedChecker : public relax::ExprVisitor,
} else {
Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression.");
}
auto true_seq = op->true_branch.as<SeqExprNode>();
auto false_seq = op->false_branch.as<SeqExprNode>();
if (true_seq && false_seq) {
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set = var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set =
symbolic_var_set_;
this->VisitSeqExpr(true_seq);
var_set_ = previous_var_set;
symbolic_var_set_ = previous_symbolic_var_set;
this->VisitSeqExpr(false_seq);
var_set_ = previous_var_set;
symbolic_var_set_ = previous_symbolic_var_set;
} else {
Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs");
}

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set = var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set =
symbolic_var_set_;
this->VisitSeqExpr(op->true_branch.get());
var_set_ = previous_var_set;
symbolic_var_set_ = previous_symbolic_var_set;
this->VisitSeqExpr(op->false_branch.get());
var_set_ = previous_var_set;
symbolic_var_set_ = previous_symbolic_var_set;

CheckStructInfo(op);
}

Expand Down
2 changes: 1 addition & 1 deletion src/relax/backend/contrib/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f) {
ICHECK(pattern) << "Unsupported op_type " << pattern_name;

auto bindings = AnalyzeVar2Value(f);
auto inner_body = Downcast<SeqExpr>(f->body)->body;
auto inner_body = f->body->body;
auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings);
ICHECK(matched_expr) << "ValueError: "
<< "For named pattern \"" << pattern_name
Expand Down
29 changes: 23 additions & 6 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,30 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
return VisitDFPattern(pattern, expr);
}

static Expr TryGetValOfVar(const Expr& expr, const Map<Var, Expr>& var2val) {
if (var2val.empty()) return expr;
static Expr TryGetValOfVar(Expr expr, const Map<Var, Expr>& var2val) {
auto unwrap = [&](Expr expr) -> Optional<Expr> {
// Unwrap variables into the value to which they are bound.
if (var2val.size()) {
if (const VarNode* var = expr.as<VarNode>()) {
if (auto may = var2val.Get(GetRef<Var>(var))) {
return may.value();
}
}
}

// Unwrap SeqExpr with no bindings. These can occur due to Relax
// IR constraints for the bodies of Function and If nodes.
if (auto seq = expr.as<SeqExprNode>()) {
if (seq->blocks.empty()) {
return seq->body;
}
}

return NullOpt;
};

// if not match, try to match value of var if expr is a var.
if (const VarNode* var = expr.as<VarNode>()) {
auto may = var2val.Get(GetRef<Var>(var));
if (may.defined()) return may.value();
while (auto unwrapped = unwrap(expr)) {
expr = unwrapped.value();
}

return expr;
Expand Down
8 changes: 8 additions & 0 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,14 @@ TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array<Binding> bind

TVM_REGISTER_NODE_TYPE(SeqExprNode);

SeqExpr::SeqExpr(Expr body) {
if (auto seq = body.as<SeqExpr>()) {
*this = seq.value();
} else {
*this = SeqExpr(Array<BindingBlock>{}, body);
}
}

SeqExpr::SeqExpr(Array<BindingBlock> blocks, Expr body, Span span) {
ObjectPtr<SeqExprNode> n = make_object<SeqExprNode>();
n->blocks = std::move(blocks);
Expand Down
7 changes: 2 additions & 5 deletions src/relax/training/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,10 @@ class AppendLossMutator : private ExprMutator {
num_backbone_outputs_(num_backbone_outputs) {}

Expr VisitExpr_(const FunctionNode* func) final {
CHECK(func->body->IsInstance<SeqExprNode>() && loss_function_->body->IsInstance<SeqExprNode>())
<< "The bodies of the backbone and the loss function must be SeqExpr.";

// Well-formed checks and setting up class members
loss_body_ = Downcast<SeqExpr>(loss_function_->body);
loss_body_ = loss_function_->body;
CheckLossBody();
BackboneReturnToArr(func->body.as<SeqExprNode>()->body);
BackboneReturnToArr(func->body->body);
CheckAndRemapBackboneReturn();
CheckAndRemapLossParams(loss_function_->params);

Expand Down
Loading