diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ce209ccd460f..aa217b0b78f9 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -278,6 +278,37 @@ class ExprVisitor : public ExprFunctor { virtual void VisitSpan(const Span& span); virtual void VisitPrimExpr(const PrimExpr& expr); + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + inline Optional LookupBinding(const Var& var) { + if (auto it = binding_table_.find(var->vid); it != binding_table_.end()) { + return it->second; + } else { + return NullOpt; + } + } + + /*! + * \brief Unwrap any known binding + * \param expr The expression to be unwrapped. + * \return The expression after following any known Var bindings + */ + inline Expr UnwrapBindings(Expr expr) { + while (true) { + auto as_var = expr.as(); + if (!as_var) return expr; + + auto bound_expr = LookupBinding(as_var.value()); + if (!bound_expr) return expr; + + expr = bound_expr.value(); + } + } + private: using TSelf = ExprVisitor; using VisitBindingVTable = @@ -308,6 +339,14 @@ class ExprVisitor : public ExprFunctor { // This visitor is not visible to child classes and only // used to supported default visiting behavior. DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; + + /*! \brief A binding table that maps var to value. + * + * Unlike ExprMutator, which can rely on the binding table of the + * internal BlockBuilder, ExprVisitor needs to track the bindings on + * its own. + */ + std::unordered_map binding_table_; }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -512,6 +551,23 @@ class ExprMutator : public ExprMutatorBase { */ Optional LookupBinding(const Var& var); + /*! + * \brief Unwrap any known binding + * \param expr The expression to be unwrapped. + * \return The expression after following any known Var bindings + */ + inline Expr UnwrapBindings(Expr expr) { + while (true) { + auto as_var = expr.as(); + if (!as_var) return expr; + + auto bound_expr = LookupBinding(as_var.value()); + if (!bound_expr) return expr; + + expr = bound_expr.value(); + } + } + /*! * \brief Post-order rewrite a node and normalize. * \tparam T The node type to be rewritten. diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index b363dc6952d8..c62b0221a476 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -25,7 +25,7 @@ from . import _ffi_api from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var from ..expr import Tuple as RxTuple -from ..struct_info import StructInfo, TensorStructInfo +from ..struct_info import StructInfo, TensorStructInfo, TupleStructInfo from ...ir import PrimExpr from ..utils import args_converter @@ -66,6 +66,18 @@ def null_value() -> Call: return _ffi_api.null_value() # type: ignore +def _normalize_arg_tuple(args: Expr) -> Expr: + if isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo): + # A tuple, or a Var bound to a tuple, are kept as-is + return args + elif isinstance(args, Expr): + # A single argument is wrapped into a tuple + return RxTuple((args,)) + else: + # Anything else is left for the FFI to handle + return args + + @args_converter.auto def call_tir( gvar: GlobalVar, @@ -97,8 +109,7 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -152,8 +163,7 @@ def call_tir_with_grad( ret: Call A call node for the call_tir_with_grad operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -220,8 +230,7 @@ def call_tir_inplace( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] @@ -275,8 +284,7 @@ def call_dps_packed( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 14a704d729e3..75e88e00e337 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -71,6 +71,7 @@ void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ this->VisitExpr(binding->value); \ this->VisitVarDef(binding->var); \ + this->binding_table_.insert({binding->var->vid, binding->value}); \ } // functions to be overriden. @@ -258,6 +259,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); this->VisitVarDef(binding->var); + this->binding_table_.insert({binding->var->vid, binding->value}); } void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 01d0d04be0cc..4429347893d2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -263,7 +263,7 @@ RELAY_REGISTER_OP("relax.call_tir") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, +Expr MakeCallTIR(Expr func, Expr arg_tuple, Array out_sinfo_list, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); @@ -283,9 +283,9 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_sinfo}); + call = Call(op, {func, arg_tuple}, {}, {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, {}, {out_sinfo}); } return call; } @@ -307,7 +307,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, +Expr MakeCallTIRWithGrad(Expr func, Expr arg_tuple, Array out_sinfo_list, String te_grad_name, Map te_grad_kwargs, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { @@ -333,9 +333,9 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo}); } return call; } @@ -364,7 +364,8 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c } // check the range for inplace indices, make sure at least one is not -1, ensure they're unique - size_t num_args = Downcast(call->args[1])->fields.size(); + auto arg_sinfo = GetStructInfoAs(call->args[1]); + size_t num_args = arg_sinfo->fields.size(); std::unordered_set encountered; for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { int index = attrs->inplace_indices[i].IntValue(); @@ -391,14 +392,13 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c // for safety, we will make sure the output shape for each in-place argument exactly matches the // input shape // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true - Tuple call_args = Downcast(call->args[1]); if (attrs->inplace_indices.size() == 1) { auto* out_sinfo = call->sinfo_args[0].as(); if (!out_sinfo) { ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[0].IntValue()]); + auto* input_sinfo = + arg_sinfo->fields[attrs->inplace_indices[0].IntValue()].as(); if (!input_sinfo || !input_sinfo->shape.defined() || !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), ctx->GetAnalyzer())) { @@ -412,24 +412,23 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c } else { auto out_sinfos = call->sinfo_args[0].as()->fields; for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { - if (attrs->inplace_indices[i].IntValue() == -1) { + int inplace_index = attrs->inplace_indices[i].IntValue(); + if (inplace_index == -1) { continue; } auto* out_sinfo = out_sinfos[i].as(); if (!out_sinfo) { ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[i].IntValue()]); + auto* input_sinfo = arg_sinfo->fields[inplace_index].as(); if (!input_sinfo || !input_sinfo->shape.defined() || !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), ctx->GetAnalyzer())) { ctx->ReportFatal(Diagnostic::Error(call) << "The shape of output " << i << " must match that of input " - << attrs->inplace_indices[i].IntValue() << ", whereas we have " - << out_sinfo->shape.value() << " in output " << i << " versus " - << input_sinfo->shape.value() << " in input " - << attrs->inplace_indices[i].IntValue()); + << inplace_index << ", whereas we have " << out_sinfo->shape.value() + << " in output " << i << " versus " << input_sinfo->shape.value() + << " in input " << inplace_index); } } } @@ -453,7 +452,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") // arguments will no longer be live) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, +Expr MakeCallTIRInplace(Expr func, Expr arg_tuple, Array inplace_indices, Array out_sinfo_list, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); @@ -476,9 +475,9 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo}); } return call; } diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index e040ccea1485..3b3e97517070 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -57,7 +57,7 @@ class CallTIRMutator : public ExprMutator { if (call->op == call_tir_op || call->op == call_tir_inplace_op || call->op == call_dps_packed_op) { - bool is_inplace = (call->op == call_tir_inplace_op); + bool is_inplace_op = (call->op == call_tir_inplace_op); const auto* inplace_attrs = call->attrs.as(); Array outs; if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { @@ -65,7 +65,7 @@ class CallTIRMutator : public ExprMutator { const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); ICHECK(tensor_sinfo->shape.defined()) << "the TensorStructInfo shape of call_tir has not populated"; - if (!is_inplace) { + if (!is_inplace_op) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, // {Downcast(tensor_sinfo->shape.value()), @@ -77,8 +77,7 @@ class CallTIRMutator : public ExprMutator { ICHECK(inplace_attrs->inplace_indices[0].IntValue() != -1) << "If calling call_tir_inplace and there is one output, its in-place index must not" " be -1."; - outs.push_back( - Downcast(call->args[1])->fields[inplace_attrs->inplace_indices[0].IntValue()]); + outs.push_back(GetTupleIndex(call->args[1], 0)); } } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { // multiple output case @@ -93,7 +92,7 @@ class CallTIRMutator : public ExprMutator { ICHECK(field_tensor->shape.defined()) << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor << " as an element of TupleStructInfo"; - if (!is_inplace || inplace_attrs->inplace_indices[i].IntValue() == -1) { + if (!is_inplace_op || inplace_attrs->inplace_indices[i].IntValue() == -1) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {Downcast(field_tensor->shape.value()), @@ -101,8 +100,8 @@ class CallTIRMutator : public ExprMutator { Attrs()), "alloc")); } else { - outs.push_back(Downcast(call->args[1]) - ->fields[inplace_attrs->inplace_indices[i].IntValue()]); + outs.push_back( + GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[i].IntValue())); } } } else { @@ -111,45 +110,85 @@ class CallTIRMutator : public ExprMutator { << expr->struct_info_; } - Array args; - if (call->args[1].as()) { - args = Downcast(call->args[1])->fields; - // for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones - if (!is_inplace) { - args.insert(args.end(), outs.begin(), outs.end()); - } else { - for (size_t i = 0; i < outs.size(); i++) { - if (inplace_attrs->inplace_indices[i].IntValue() == -1) { - args.push_back(outs[i]); - } + Expr callee = call->args[0]; + Expr arg_tuple = call->args[1]; + Optional shape_tuple_of_tir_args = NullOpt; + if (call->args.size() > 2) { + shape_tuple_of_tir_args = call->args[2]; + } + + while (true) { + auto as_var = arg_tuple.as(); + if (!as_var) break; + + auto bound_expr = LookupBinding(as_var.value()); + if (!bound_expr) break; + + arg_tuple = bound_expr.value(); + } + + Array args = [&]() { + if (auto ptr = arg_tuple.as()) { + return ptr->fields; + } else if (auto ptr = arg_tuple->struct_info_.as()) { + size_t n_args = ptr->fields.size(); + Array args; + for (size_t i = 0; i < n_args; i++) { + args.push_back(TupleGetItem(arg_tuple, i)); } + return args; + } else { + LOG(FATAL) << "Lowering of " << call + << " requires knowing how many arguments are passed to the function. " + << "However, the tuple of arguments " << arg_tuple + << " is not itself a tuple, " + << "nor does its struct info " << GetStructInfo(arg_tuple) + << " define the number of arguments."; } + }(); - if (call->args.size() == 2) { - builder_->Emit(Call(call->args[0], args), "_"); - } else { - // unpack semantics - args.push_back(call->args[2]); - builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); + for (size_t i = 0; i < outs.size(); i++) { + bool output_is_inplace = + is_inplace_op && inplace_attrs->inplace_indices[i].IntValue() != -1; + if (!output_is_inplace) { + args.push_back(outs[i]); } - } else { - if (!is_inplace) { - args = outs; - args.insert(args.begin(), call->args[1]); + } + + if (shape_tuple_of_tir_args) { + args.push_back(shape_tuple_of_tir_args.value()); + } + + Expr new_call = [&]() { + if (shape_tuple_of_tir_args) { + return Call(call_tir_dyn_op, {callee, Tuple(args)}); } else { - args.push_back(call->args[1]); + return Call(callee, args); } - builder_->Emit(Call(call->args[0], args), "_"); - } + }(); + builder_->Emit(new_call, "_"); if (outs.size() == 1) { return outs[0]; + } else { + return Tuple(outs); } - return std::move(Tuple(outs)); } return GetRef(call); } + + private: + // If e is a tuple literal, return the field denoted by the index. + // Otherwise, insert a tuple get item for that field and return the + // var the result is bound to. + Expr GetTupleIndex(const Expr& e, int index) { + if (const auto* tuple_node = e.as()) { + return tuple_node->fields[index]; + } + auto out = builder_->Emit(TupleGetItem(e, index)); + return out; + } }; Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index d6da79c484cf..963fb601829b 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -177,9 +177,27 @@ class ConstantFolder : public ExprMutator { // call_tir needs to have at least three arguments ICHECK_GE(call->args.size(), 2); Optional func = MatchPrimFunc(call->args[0]); - ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; - Optional> arr_args = - MatchConstArrayArgs(call->args[1].as()->fields); + + Expr arg_expr = call->args[1]; + ICHECK(GetStructInfo(arg_expr).as()) + << "call_tir.args[1] must be Tuple" + << ", but instead args[1] was " << arg_expr << " with struct info " + << GetStructInfo(arg_expr); + + // The arguments to call_tir may be a variable bound to a tuple, + // rather than the tuple itself. Unwrap any such variable + // bindings if they are known. + arg_expr = UnwrapBindings(arg_expr); + + // The arguments may still be a variable bound to a tuple, where + // the tuple was produced outside the current function. In that + // case, we cannot fold the expression. + auto arg_tuple = arg_expr.as(); + if (!arg_tuple) { + return NullOpt; + } + + Optional> arr_args = MatchConstArrayArgs(arg_tuple->fields); ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; Optional shape = MatchConstShape(call->sinfo_args[0]); bool output_not_tuple = call->sinfo_args.size() == 1; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 5cabfc40ca9c..6d13a9b3f793 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -150,12 +150,16 @@ class GraphCreator : public ExprVisitor { } void VisitBinding_(const MatchCastNode* binding) final { + ExprVisitor::VisitBinding_(binding); + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); SetNodePattern(node, OpPatternKind::kOpaque); AddToPostDFSOrder(node, binding->var.get()); } void VisitBinding_(const VarBindingNode* binding) final { + ExprVisitor::VisitBinding_(binding); + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); // If the variable is not a dataflow variable, it must be the output variable of this dataflow @@ -169,9 +173,12 @@ class GraphCreator : public ExprVisitor { } else if (const auto* tuple_get_item = binding->value.as()) { // Case 2. The expression is a TupleGetItemNode VisitTupleGetItem(tuple_get_item, node); + } else if (const auto* tuple = binding->value.as()) { + // Case 3. The expression is a Tuple + VisitTuple(tuple, node); } else { VisitUnsupportedNode(binding->value, node); - // Case 3. The type of the expression is not fusion-supported. + // Case 4. The type of the expression is not fusion-supported. // In this case, we skip adding edges, adding an empty node into graph. } AddToPostDFSOrder(node, binding->var.get()); @@ -195,8 +202,16 @@ class GraphCreator : public ExprVisitor { const GlobalVar& global_var = Downcast(call->args[0]); tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); - // Override args for call_tir - args = Downcast(call->args[1])->fields; + // Override args for call_tir, where possible. If the argument + // is a variable bound to a tuple, do not unwrap it, as doing so + // would prevent the tuple object from being exposed to the + // GraphPartitioner. + Expr arg_tuple = call->args[1]; + if (auto ptr = arg_tuple.as()) { + args = ptr->fields; + } else { + args = {arg_tuple}; + } Optional opt_pattern = func->GetAttr("op_pattern"); if (opt_pattern.defined()) { @@ -217,6 +232,13 @@ class GraphCreator : public ExprVisitor { } } + void VisitTuple(const TupleNode* tuple, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kTuple); + VisitLeaf(GetRef(tuple), binding_var_node, OpPatternKind::kTuple); + } + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, IndexedForwardGraph::Node* binding_var_node) { ICHECK_NOTNULL(binding_var_node); @@ -248,17 +270,11 @@ class GraphCreator : public ExprVisitor { for (const Expr& expr : tuple->fields) { VisitLeaf(expr, binding_var_node, pattern); } - return; - } - - if (!leaf_expr->IsInstance()) { - // Skip GlobalVar, ExternFunc, OpNode. - return; } - auto it = graph_.node_map.find(leaf_expr.get()); IndexedForwardGraph::Node* leaf_node = nullptr; - if (it != graph_.node_map.end()) { + + if (auto it = graph_.node_map.find(leaf_expr.get()); it != graph_.node_map.end()) { leaf_node = it->second; } else if (leaf_expr->IsInstance() || leaf_expr->IsInstance() || leaf_expr->IsInstance() || leaf_expr->IsInstance() || @@ -267,11 +283,11 @@ class GraphCreator : public ExprVisitor { // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. SetNodePattern(leaf_node, OpPatternKind::kOpaque); AddToPostDFSOrder(leaf_node, leaf_expr.get()); - } else { - LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr - << " used before definition."; } - AddEdge(leaf_node, binding_var_node, pattern); + + if (leaf_node) { + AddEdge(leaf_node, binding_var_node, pattern); + } } /********** Helper Functions **********/ @@ -387,14 +403,21 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { + Array args; + if (call->op == Op::Get("relax.call_tir")) { // Update the name of the function. name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; - const Tuple& args = Downcast(call->args[1]); - for (const Expr& arg : args->fields) { - CheckDefAndUpdateParam(arg); - ICHECK(GetStructInfoAs(arg) == nullptr); + // Unwrap a tuple variable to a tuple expression, if + // possible. This treats use of a relax::Var that points to + // a tuple to be treated as a direct usage of the underlying + // relax::Var, preventing duplicate parameters. + Expr arg_tuple = UnwrapBindings(call->args[1]); + if (auto ptr = arg_tuple.as()) { + args = ptr->fields; + } else { + args = {arg_tuple}; } // TODO(tvm-team): handle shape expr } else { @@ -409,27 +432,29 @@ class FunctionCreator : public ExprMutator { } } - for (const Expr& arg : call->args) { - CheckDefAndUpdateParam(arg); - if (GetStructInfoAs(arg) != nullptr) { - // The argument is fully referenced. Thus we remove it from the mapping. - partially_used_tuple_params_.erase(arg.get()); - } + args = call->args; + } + + for (const Expr& arg : args) { + CheckDefAndUpdateParam(arg); + + if (arg->struct_info_.as()) { + // Mark the tuple as being used as a tuple object. Therefore, we'll retain it as a + // tuple parameter. + tuple_param_info_[arg.get()].is_fully_used = true; } } - } else if (var_binding->value.as()) { - const auto* tuple_item = var_binding->value.as(); + + } else if (const auto* tuple_item = var_binding->value.as()) { CheckDefAndUpdateParam(tuple_item->tuple); - if (partially_used_tuple_params_.find(tuple_item->tuple.get()) != - partially_used_tuple_params_.end()) { - // Appending get-item index to the mapping. - partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index); - } + tuple_param_info_[tuple_item->tuple.get()].used_by_index.push_back(tuple_item->index); + } else if (const auto* tuple = var_binding->value.as()) { + tuple_param_info_[tuple].is_fully_used = true; } // Mark the binding variable as defined. - defined_vars_.insert(var_binding->var.get()); + defined_vars_.insert({var_binding->var.get(), var_binding->value}); // Set var as output true if the binding is not a dataflow variable if (!var_binding->var->IsInstance()) { AppendOutput(var_binding->var); @@ -465,27 +490,31 @@ class FunctionCreator : public ExprMutator { // parameters with the parameters of its fields that are accessed in the // function. std::unordered_map> tuple_get_item_remap; - for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) { - ICHECK(!item_indices.empty()); - int param_idx = tuple_param_idx_[tuple_arg]; - Var param = params_[param_idx]; - String param_name = params_[param_idx]->name_hint(); - TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); - - Array item_args; - Array item_params; - item_args.reserve(item_indices.size()); - item_params.reserve(item_indices.size()); - for (int item_idx : item_indices) { - Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]); - item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); - item_params.push_back(item_param); - tuple_get_item_remap[tuple_arg][item_idx] = item_param; + for (const auto& [tuple_arg, tuple_info] : tuple_param_info_) { + if (!tuple_info.is_fully_used && tuple_info.param_index.has_value()) { + const auto& item_indices = tuple_info.used_by_index; + ICHECK(!item_indices.empty()); + int param_idx = tuple_info.param_index.value(); + Var param = params_[param_idx]; + String param_name = params_[param_idx]->name_hint(); + TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); + + Array item_args; + Array item_params; + item_args.reserve(item_indices.size()); + item_params.reserve(item_indices.size()); + for (int item_idx : item_indices) { + Var item_param(param_name + "_" + std::to_string(item_idx), + param_sinfo->fields[item_idx]); + item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); + item_params.push_back(item_param); + tuple_get_item_remap[tuple_arg][item_idx] = item_param; + } + arguments_.erase(arguments_.begin() + param_idx); + arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end()); + params_.erase(params_.begin() + param_idx); + params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end()); } - arguments_.erase(arguments_.begin() + param_idx); - arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end()); - params_.erase(params_.begin() + param_idx); - params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end()); } // Step 3. Visit each binding and collect outputs one by one. @@ -606,8 +635,7 @@ class FunctionCreator : public ExprMutator { // Mark the tuple parameter is partially referenced in the beginning. // We will remove it from the mapping once we find it is fully referenced. if (param_sinfo->IsInstance()) { - partially_used_tuple_params_[expr.get()] = {}; - tuple_param_idx_[expr.get()] = static_cast(arguments_.size()) - 1; + tuple_param_info_[expr.get()].param_index = static_cast(arguments_.size()) - 1; } } } @@ -623,21 +651,47 @@ class FunctionCreator : public ExprMutator { } private: + /* \brief Shadow the ExprMutator::UnwrapBindings + * + * Because the ExprMutator only knows of bindings that have been + * visited, and we do not call VisitBinding until FunctionCreate, we + * cannot use it to unwrap the known binding. Therefore, + * reproducing the same functionality here. + */ + Expr UnwrapBindings(Expr expr) { + while (true) { + auto var = expr.as(); + if (!var) return expr; + + auto it = defined_vars_.find(var.get()); + if (it == defined_vars_.end()) return expr; + + expr = it->second; + } + } + /*! \brief The variables defined in this function */ - std::unordered_set defined_vars_; + std::unordered_map defined_vars_; /*! \brief The number of parameters reserved for constants */ int n_param_for_const_ = 0; /*! \brief The output vars */ std::vector output_vars_; /*! \brief Whether or not to lift bound constants to parameters */ bool lift_constant_; - /*! \brief Mapping from tuple parameter of the function to its position index */ - std::unordered_map tuple_param_idx_; + + struct TupleParamInfo { + std::optional param_index = std::nullopt; + bool is_fully_used{false}; + std::vector used_by_index; + }; + /*! - * \brief Mapping from partially referenced tuple parameter to the list of - * indices that the parameter is referred by TupleGetItem + * \brief Mapping from tuple parameters to collected information about them. + * + * Used to decide whether to pass individual tuple elements to the + * fused function, or entire tuple objects. */ - std::unordered_map> partially_used_tuple_params_; + std::unordered_map tuple_param_info_; }; /*! diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 8345f3e0b745..6c96e2b99560 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -82,8 +82,11 @@ class DataflowReshapeRewriter : public ExprMutator { // vm.builtin.reshape in the VMBuiltinLower pass. auto prim_fn = Downcast(mod_->Lookup(Downcast(call->args[0]))); - auto arg_tuple = Downcast(call->args[1])->fields; - auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size()); + + auto args_expr = call->args[1]; + auto args_sinfo = GetStructInfo(args_expr).as(); + + auto used_arg_indices = GetUsedArgsIndices(prim_fn, args_sinfo->fields.size()); // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR @@ -92,18 +95,32 @@ class DataflowReshapeRewriter : public ExprMutator { if (used_arg_indices.size() != 1) { return GetRef(call); } + size_t arg_index = used_arg_indices[0]; - auto arg = arg_tuple[used_arg_indices[0]]; + auto arg_sinfo = Downcast(args_sinfo->fields[arg_index]); - if (!IsCallingTIRReshape(call, arg)) { + if (!IsCallingTIRReshape(call, arg_sinfo)) { return GetRef(call); } + // Now we know that we're calling a reshape, but we don't yet know + // on what. Ideally, the arguments are either a tuple, or a + // variable that is bound to a known tuple, but we may need to + // fall back to a TupleGetItem. + args_expr = UnwrapBindings(args_expr); + auto arg = [&]() -> Expr { + if (auto known_tuple = args_expr.as()) { + return known_tuple->fields[arg_index]; + } else { + return TupleGetItem(args_expr, arg_index); + } + }(); + TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); return reshape(arg, res_sinfo->shape.value()); } - bool IsCallingTIRReshape(const CallNode* call, Expr inp) { + bool IsCallingTIRReshape(const CallNode* call, TensorStructInfo inp_sinfo) { const GlobalVar& global_var = Downcast(call->args[0]); const auto* func = mod_->functions.Get(global_var).as(); ICHECK_NOTNULL(func); @@ -115,8 +132,7 @@ class DataflowReshapeRewriter : public ExprMutator { // as the number of elements in the result. There are operators that could have a reshape // pattern that don't meet this requirement (e.g. strided_slice), and they should not be // converted to reshape. - ICHECK(inp->struct_info_.defined() && call->struct_info_.defined()); - TensorStructInfo inp_sinfo = Downcast(inp->struct_info_.value()); + ICHECK(call->struct_info_.defined()); TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) { diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 9ab2ffc60536..c35d82d1a4c3 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -21,7 +21,7 @@ from tvm.ir import structural_equal import tvm.script -from tvm.script import tir as T, relax as R +from tvm.script import tir as T, relax as R, ir as I def test_to_non_dataflow(): @@ -84,8 +84,8 @@ def fvisit(e): def test_call_tir_rewrite(): - @tvm.script.ir_module - class TestCallTIRRewrite: + @I.ir_module + class Before: @T.prim_func def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): T.evaluate(0) @@ -95,32 +95,119 @@ def foo(x: R.Tensor(("m", "n"), "float32")): # we expect RemovePurityChecking to have been used before this point R.func_attr({"relax.force_pure": True}) m, n = T.int64(), T.int64() - gv0 = R.call_tir(TestCallTIRRewrite.exp, (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir(Before.exp, (x,), R.Tensor((m, n), dtype="float32")) return gv0 - mod = TestCallTIRRewrite + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) - # before rewrite - v0 = mod["foo"].body.blocks[0].bindings[0].var - s0 = mod["foo"].body.blocks[0].bindings[0].value - assert isinstance(s0, relax.Call) - assert s0.op.name == "relax.call_tir" + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + alloc = R.builtin.alloc_tensor(R.shape([m, n]), "float32", R.prim_value(0)) + _ = Expected.exp(x, alloc) + gv0 = alloc + return gv0 - # after rewrite - new_mod = relax.transform.CallTIRRewrite()(mod) - func = new_mod["foo"] + After = relax.transform.CallTIRRewrite()(Before) + tvm.ir.assert_structural_equal(Expected, After) - block = func.body.blocks[0] - assert not isinstance(block, relax.DataflowBlock) - s1 = block.bindings[0].value - assert isinstance(s1, relax.Call) - assert s1.op.name == "relax.builtin.alloc_tensor" - assert isinstance(s1.args[0], relax.ShapeExpr) - assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) - s2 = block.bindings[1].value - tvm.ir.expr.GlobalVar - assert s2.op.name_hint == "exp" +def test_call_tir_rewrite_with_var_tuple(): + @I.ir_module + class Before: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + exp_args = (x,) + gv0 = R.call_tir(Before.exp, exp_args, R.Tensor((m, n), dtype="float32")) + return gv0 + + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + exp_args = (x,) + alloc = R.builtin.alloc_tensor(R.shape([m, n]), "float32", R.prim_value(0)) + _ = Expected.exp(x, alloc) + gv0 = alloc + return gv0 + + After = relax.transform.CallTIRRewrite()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_call_tir_rewrite_separate_tuple(): + # if the arguments to call_tir are tuple-typed but not a tuple literal, + # the rewrite should index into the tuple + @tvm.script.ir_module + class TestCallTIRRewrite: + @T.prim_func + def add( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + C: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def foo(x: R.Tensor((2, 3), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + tup = (x, x) + gv0 = R.call_tir(TestCallTIRRewrite.add, tup, R.Tensor((2, 3), dtype="float32")) + return gv0 + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + C: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def foo(x: R.Tensor((2, 3), "float32")): + R.func_attr({"relax.force_pure": True}) + tup = (x, x) + alloc = R.builtin.alloc_tensor(R.shape([2, 3]), dtype="float32", runtime_device_index=0) + _ = Expected.add(x, x, alloc) + gv0 = alloc + return gv0 + + new_mod = relax.transform.CallTIRRewrite()(TestCallTIRRewrite) + tvm.ir.assert_structural_equal(new_mod, Expected) def test_transform_remove_purity_checking(): diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 9f2e3a4a092d..0cacbab28d04 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -84,6 +84,38 @@ def expected(c1: R.Tensor((16, 16), "float32")): tvm.ir.assert_structural_equal(after, expected) +def test_one_fold_addone_with_arg_tuple(): + """Like test_one_fold_addone, but without an inline tuple""" + + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + cls = Module + arg_tuple = (c0,) + lv0 = relax.call_tir(cls.addone, arg_tuple, R.Tensor((16, 16), dtype="float32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + def test_one_fold_transpose(): # put before after in a single module @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 1a4a630e3e5a..c60312deda3b 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1501,5 +1501,47 @@ def main( _check(Module, Expected) +def test_fuse_call_tir_with_var_tuple(): + """Call_tir may contain either inline tuple or a var""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.sin, x) + lv1 = bb.emit_te(topi.cos, x) + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + + with bb.function("fused_sin_cos_add", [x], attrs={"Primitive": 1}, private=True): + with bb.dataflow(): + lv0 = bb.emit_te(topi.sin, x) + lv1 = bb.emit_te(topi.cos, x) + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1)) + bb.emit_func_output(gv) + fused_sin_cos_add = bb.get().get_global_var("fused_sin_cos_add") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_sin_cos_add, [x])) + bb.emit_func_output(gv) + + return bb.get() + + before = relax.transform.EliminateCommonSubexpr()(before()) + expected = relax.transform.EliminateCommonSubexpr()(expected()) + + _check(before, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index 01065bea21df..4680e51a6888 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax -from tvm.script import relax as R, tir as T +from tvm.script import relax as R, tir as T, ir as I def test_reshape_expand_dims(): @@ -500,5 +500,76 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): tvm.ir.assert_structural_equal(rewritten, Expected) +def test_reshape_with_tuple_var_as_argument(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(1),), "float32"), + B: T.Buffer((T.int64(1),), "float32"), + T_add: T.Buffer((T.int64(1),), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_add"): + vi = T.axis.spatial(T.int64(1), i) + T_add[vi] = A[vi] + B[vi] + + @T.prim_func(private=True) + def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_reshape"): + vi = T.axis.spatial(T.int64(1), i) + T_reshape[vi] = A[()] + + @R.function + def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): + cls = Before + with R.dataflow(): + reshape_args = (x,) + y = R.call_tir(cls.reshape, reshape_args, out_sinfo=R.Tensor((1,), dtype="float32")) + add_args = (y, y) + z = R.call_tir(cls.add, add_args, out_sinfo=R.Tensor((1,), dtype="float32")) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(1),), "float32"), + B: T.Buffer((T.int64(1),), "float32"), + T_add: T.Buffer((T.int64(1),), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_add"): + vi = T.axis.spatial(T.int64(1), i) + T_add[vi] = A[vi] + B[vi] + + @T.prim_func(private=True) + def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_reshape"): + vi = T.axis.spatial(T.int64(1), i) + T_reshape[vi] = A[()] + + @R.function + def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): + cls = Expected + with R.dataflow(): + reshape_args = (x,) + y = R.reshape(x, R.shape([1])) + add_args = (y, y) + z = R.call_tir(cls.add, add_args, out_sinfo=R.Tensor((1,), dtype="float32")) + R.output(z) + return z + + After = relax.transform.RewriteDataflowReshape()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 4f28c4a47a69..8dfc63878a15 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -231,9 +231,11 @@ def main( ) -> R.Tuple( R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") ): + # also make sure it works with a tuple bound separately + tup = (x, y, z) res = R.call_tir_inplace( TestCallTIRInplaceE2ESimple.copy, - (x, y, z), + tup, [0, 1, -1], [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], ) @@ -303,6 +305,48 @@ def main( tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_call_tir_reuse_tuple_input(exec_mode): + # read and write from the same tensor + @tvm.script.ir_module + class TestCallTIRTupleInput: + @T.prim_func + def add( + A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")): + tup = (x, y) + res1 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + res2 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + return (res1, res2) + + mod = TestCallTIRTupleInput + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + x = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + y = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + vm.set_input("main", x, y) + vm.invoke_stateful("main") + out = vm.get_outputs("main") + expected = tvm.nd.array(np.full((2, 3), 2).astype(np.int32)) + tvm.testing.assert_allclose(out[0].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(out[1].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + + def test_vm_emit_te_extern(exec_mode): if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): print("skip because extern function is not available") @@ -763,7 +807,8 @@ def relax_matmul_tir( ) -> R.Tensor((32, 32), dtype="float32"): cls = TestVMSubFunction with R.dataflow(): - gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + tup = (x, w) + gv0 = R.call_tir(cls.tir_matmul, tup, R.Tensor((32, 32), dtype="float32")) R.output(gv0) return gv0