56 changes: 56 additions & 0 deletions include/tvm/relax/expr_functor.h
@@ -278,6 +278,37 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
virtual void VisitSpan(const Span& span);
virtual void VisitPrimExpr(const PrimExpr& expr);

/*!
* \brief Look up the value bound to a variable.
* \param var The var to be looked up.
* \return The value bound to the input \p var.
* \note For function parameters, this function returns NullOpt.
*/
inline Optional<Expr> LookupBinding(const Var& var) {
if (auto it = binding_table_.find(var->vid); it != binding_table_.end()) {
return it->second;
} else {
return NullOpt;
}
}

/*!
* \brief Unwrap any known binding.
* \param expr The expression to be unwrapped.
* \return The expression after following any known Var bindings.
*/
inline Expr UnwrapBindings(Expr expr) {
while (true) {
auto as_var = expr.as<Var>();
if (!as_var) return expr;

auto bound_expr = LookupBinding(as_var.value());
if (!bound_expr) return expr;

expr = bound_expr.value();
}
}

private:
using TSelf = ExprVisitor;
using VisitBindingVTable =
@@ -308,6 +339,14 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
// This visitor is not visible to child classes and only
// used to support default visiting behavior.
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};

/*! \brief A binding table that maps var to value.
*
* Unlike ExprMutator, which can rely on the binding table of the
* internal BlockBuilder, ExprVisitor needs to track the bindings on
* its own.
*/
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> binding_table_;
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
@@ -512,6 +551,23 @@ class ExprMutator : public ExprMutatorBase {
*/
Optional<Expr> LookupBinding(const Var& var);

/*!
* \brief Unwrap any known binding.
* \param expr The expression to be unwrapped.
* \return The expression after following any known Var bindings.
*/
inline Expr UnwrapBindings(Expr expr) {
while (true) {
auto as_var = expr.as<Var>();
if (!as_var) return expr;

auto bound_expr = LookupBinding(as_var.value());
if (!bound_expr) return expr;

expr = bound_expr.value();
}
}

/*!
* \brief Post-order rewrite a node and normalize.
* \tparam T The node type to be rewritten.
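A minimal sketch of how a downstream `ExprVisitor` subclass might use the new helpers; the class name and the logging are illustrative, not part of this change:

```cpp
#include <tvm/relax/expr_functor.h>

namespace tvm {
namespace relax {

// Hypothetical pass: inspect call_tir argument tuples, even when the
// tuple was first bound to a Var instead of appearing in-line.
class CallTIRArgInspector : public ExprVisitor {
  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    static const Op& call_tir_op = Op::Get("relax.call_tir");
    if (call->op.same_as(call_tir_op)) {
      // UnwrapBindings follows the Var bindings recorded by the base
      // visitor, so this sees through `y = (a, b); call_tir(f, y, ...)`.
      Expr arg_tuple = UnwrapBindings(call->args[1]);
      if (const auto* tuple = arg_tuple.as<TupleNode>()) {
        LOG(INFO) << "call_tir with " << tuple->fields.size() << " arguments";
      }
    }
  }
};

}  // namespace relax
}  // namespace tvm
```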
26 changes: 17 additions & 9 deletions python/tvm/relax/op/base.py
@@ -25,7 +25,7 @@
from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ..struct_info import StructInfo, TensorStructInfo, TupleStructInfo
from ...ir import PrimExpr
from ..utils import args_converter

@@ -66,6 +66,18 @@ def null_value() -> Call:
return _ffi_api.null_value() # type: ignore


def _normalize_arg_tuple(args: Expr) -> Expr:
    """Normalize the arguments of a call_tir-style operator to a tuple.

    An in-line relax Tuple, or any expression whose struct info is a
    TupleStructInfo (e.g. a Var bound to a tuple), is kept as-is.  A
    single argument is wrapped into a one-element tuple.
    """
    if isinstance(args, RxTuple):
        # An in-line tuple is kept as-is
        return args
    elif isinstance(args, Expr):
        if isinstance(args.struct_info_, TupleStructInfo):
            # A Var bound to a tuple is kept as-is
            return args
        # A single argument is wrapped into a tuple
        return RxTuple((args,))
    else:
        # Anything else is left for the FFI to handle
        return args


@args_converter.auto
def call_tir(
gvar: GlobalVar,
@@ -97,8 +109,7 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
@@ -152,8 +163,7 @@ def call_tir_with_grad(
ret: Call
A call node for the call_tir_with_grad operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
@@ -220,8 +230,7 @@ def call_tir_inplace(
ret: Call
A call node for the call_tir_inplace operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(inplace_indices, list):
inplace_indices = [inplace_indices]
@@ -275,8 +284,7 @@ def call_dps_packed(
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
2 changes: 2 additions & 0 deletions src/relax/ir/expr_functor.cc
@@ -71,6 +71,7 @@
void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \
this->VisitExpr(binding->value); \
this->VisitVarDef(binding->var); \
this->binding_table_.insert({binding->var->vid, binding->value}); \
}

// functions to be overridden.
@@ -258,6 +259,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode);
void ExprVisitor::VisitBinding_(const MatchCastNode* binding) {
this->VisitExpr(binding->value);
this->VisitVarDef(binding->var);
this->binding_table_.insert({binding->var->vid, binding->value});
}

void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) {
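Since both `VarBindingNode` and `MatchCastNode` now populate `binding_table_`, `LookupBinding` resolves vars bound by either kind of binding. A hedged sketch of a subclass reading the table back (`BindingLogger` is illustrative only):

```cpp
// The base implementation invoked below inserts the binding before
// returning, so the lookup is guaranteed to succeed here.
class BindingLogger : public ExprVisitor {
 public:
  void VisitBinding_(const VarBindingNode* binding) override {
    ExprVisitor::VisitBinding_(binding);
    if (Optional<Expr> value = LookupBinding(binding->var)) {
      LOG(INFO) << binding->var->name_hint() << " is bound to " << value.value();
    }
  }
};
```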
39 changes: 19 additions & 20 deletions src/relax/op/op.cc
@@ -263,7 +263,7 @@ RELAY_REGISTER_OP("relax.call_tir")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expr MakeCallTIR(Expr func, Expr arg_tuple, Array<TensorStructInfo> out_sinfo_list,
Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
@@ -283,9 +283,9 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Call call;
if (!packed_ints) {
// don't use additional optional argument
call = Call(op, {func, args}, {}, {out_sinfo});
call = Call(op, {func, arg_tuple}, {}, {out_sinfo});
} else {
call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo});
call = Call(op, {func, arg_tuple, packed_ints.value()}, {}, {out_sinfo});
}
return call;
}
@@ -307,7 +307,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expr MakeCallTIRWithGrad(Expr func, Expr arg_tuple, Array<TensorStructInfo> out_sinfo_list,
String te_grad_name, Map<String, ObjectRef> te_grad_kwargs,
Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
@@ -333,9 +333,9 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinf
Call call;
if (!packed_ints) {
// don't use additional optional argument
call = Call(op, {func, args}, Attrs(attrs), {out_sinfo});
call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo});
} else {
call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo});
call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo});
}
return call;
}
@@ -364,7 +364,8 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
}

// check the range for inplace indices, make sure at least one is not -1, ensure they're unique
size_t num_args = Downcast<Tuple>(call->args[1])->fields.size();
auto arg_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1]);
size_t num_args = arg_sinfo->fields.size();
std::unordered_set<int> encountered;
for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
int index = attrs->inplace_indices[i].IntValue();
@@ -391,14 +392,13 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
// for safety, we will make sure the output shape for each in-place argument exactly matches the
// input shape
// TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
Tuple call_args = Downcast<Tuple>(call->args[1]);
if (attrs->inplace_indices.size() == 1) {
auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[0].IntValue()]);
auto* input_sinfo =
arg_sinfo->fields[attrs->inplace_indices[0].IntValue()].as<TensorStructInfoNode>();
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {
@@ -412,24 +412,23 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
} else {
auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
if (attrs->inplace_indices[i].IntValue() == -1) {
int inplace_index = attrs->inplace_indices[i].IntValue();
if (inplace_index == -1) {
continue;
}
auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[i].IntValue()]);
auto* input_sinfo = arg_sinfo->fields[inplace_index].as<TensorStructInfoNode>();
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "The shape of output " << i << " must match that of input "
<< attrs->inplace_indices[i].IntValue() << ", whereas we have "
<< out_sinfo->shape.value() << " in output " << i << " versus "
<< input_sinfo->shape.value() << " in input "
<< attrs->inplace_indices[i].IntValue());
<< inplace_index << ", whereas we have " << out_sinfo->shape.value()
<< " in output " << i << " versus " << input_sinfo->shape.value()
<< " in input " << inplace_index);
}
}
}
Expand All @@ -453,7 +452,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
// arguments will no longer be live)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRInplace(Expr func, Tuple args, Array<Integer> inplace_indices,
Expr MakeCallTIRInplace(Expr func, Expr arg_tuple, Array<Integer> inplace_indices,
Array<TensorStructInfo> out_sinfo_list, Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
@@ -476,9 +475,9 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array<Integer> inplace_indices,
Call call;
if (!packed_ints) {
// don't use additional optional argument
call = Call(op, {func, args}, Attrs(attrs), {out_sinfo});
call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo});
} else {
call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo});
call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo});
}
return call;
}
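One consequence of relaxing `Tuple args` to `Expr arg_tuple` in the makers above: a `relax.call_tir` call can now be constructed with any expression carrying `TupleStructInfo` in the argument slot. A sketch under that assumption (the helper name is made up; real callers should go through `MakeCallTIR` or the Python `call_tir`):

```cpp
#include <tvm/ir/op.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

namespace tvm {
namespace relax {

// Hypothetical helper: build a call_tir whose arguments arrive as a Var
// of TupleStructInfo rather than as an in-line relax::Tuple node.
Call CallTIRThroughVar(GlobalVar tir_func, Var packed_args,
                       TensorStructInfo out_sinfo) {
  static const Op& call_tir_op = Op::Get("relax.call_tir");
  // args[1] only needs to carry TupleStructInfo; struct-info inference
  // now reads the argument fields from there instead of downcasting
  // the expression to Tuple.
  return Call(call_tir_op, {tir_func, packed_args}, Attrs(), {out_sinfo});
}

}  // namespace relax
}  // namespace tvm
```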