From a00fe500e41af4049ee0b8a0e5a87ff7cc4e8ff0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 3 Oct 2023 18:21:26 +0000 Subject: [PATCH 01/10] [Unity] Provide LookupBinding utility in ExprVisitor The `ExprMutator` class provides a `LookupBinding` utility for use by subclasses. This commit provides the same functionality to subclasses of `ExprVisitor`. --- include/tvm/relax/expr_functor.h | 22 ++++++++++++++++++++++ src/relax/ir/expr_functor.cc | 2 ++ 2 files changed, 24 insertions(+) diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ce209ccd460f..f5d46a6010db 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -278,6 +278,20 @@ class ExprVisitor : public ExprFunctor { virtual void VisitSpan(const Span& span); virtual void VisitPrimExpr(const PrimExpr& expr); + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + inline Optional LookupBinding(const Var& var) { + if (auto it = binding_table_.find(var->vid); it != binding_table_.end()) { + return it->second; + } else { + return NullOpt; + } + } + private: using TSelf = ExprVisitor; using VisitBindingVTable = @@ -308,6 +322,14 @@ class ExprVisitor : public ExprFunctor { // This visitor is not visible to child classes and only // used to supported default visiting behavior. DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; + + /*! \brief A binding table that maps var to value. + * + * Unlike ExprMutator, which can rely on the binding table of the + * internal BlockBuilder, ExprVisitor needs to track the bindings on + * its own. + */ + std::unordered_map binding_table_; }; void PostOrderVisit(const Expr& node, std::function fvisit); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 14a704d729e3..75e88e00e337 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -71,6 +71,7 @@ void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ this->VisitExpr(binding->value); \ this->VisitVarDef(binding->var); \ + this->binding_table_.insert({binding->var->vid, binding->value}); \ } // functions to be overriden. @@ -258,6 +259,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); this->VisitVarDef(binding->var); + this->binding_table_.insert({binding->var->vid, binding->value}); } void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { From 5bd481925e2fa77345d89af207fa72aeca892e87 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 Oct 2023 16:57:38 +0000 Subject: [PATCH 02/10] [Unity] Add utility method UnwrapBindings for ExprVisitor/Mutator --- include/tvm/relax/expr_functor.h | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index f5d46a6010db..aa217b0b78f9 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -292,6 +292,23 @@ class ExprVisitor : public ExprFunctor { } } + /*! + * \brief Unwrap any known binding + * \param expr The expression to be unwrapped. 
+   * \return The expression after following any known Var bindings
+   */
+  inline Expr UnwrapBindings(Expr expr) {
+    while (true) {
+      auto as_var = expr.as<Var>();
+      if (!as_var) return expr;
+
+      auto bound_expr = LookupBinding(as_var.value());
+      if (!bound_expr) return expr;
+
+      expr = bound_expr.value();
+    }
+  }
+
  private:
   using TSelf = ExprVisitor;
   using VisitBindingVTable =
@@ -534,6 +551,23 @@ class ExprMutator : public ExprMutatorBase {
    */
   Optional<Expr> LookupBinding(const Var& var);
 
+  /*!
+   * \brief Unwrap any known binding
+   * \param expr The expression to be unwrapped.
+   * \return The expression after following any known Var bindings
+   */
+  inline Expr UnwrapBindings(Expr expr) {
+    while (true) {
+      auto as_var = expr.as<Var>();
+      if (!as_var) return expr;
+
+      auto bound_expr = LookupBinding(as_var.value());
+      if (!bound_expr) return expr;
+
+      expr = bound_expr.value();
+    }
+  }
+
   /*!
    * \brief Post-order rewrite a node and normalize.
    * \tparam T The node type to be rewritten.

From 7c426af451c46a0ae97f499804ebaf01f9efa0d8 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 3 Oct 2023 19:21:51 +0000
Subject: [PATCH 03/10] [Unity] Handle bound tuple variables in
 relax.transform.FoldConstant

Prior to this commit, the `relax.transform.FoldConstant` pass assumed
that the "relax.call_tir" builtin had arguments expressed as
`relax.Tuple`, and failed if provided with a `relax.Var` that had been
bound to a tuple.

This commit updates the `FoldConstant` pass to handle variables
annotated with `TupleStructInfo`.  If the variable's value was
determined within the scope of the mutated function, we can look up
the bound tuple and find the argument.  If the variable's value was
produced as output from another function, then we cannot use it in a
constant expression, and must leave it as-is.
---
 python/tvm/relax/op/base.py                   |  4 +--
 src/relax/op/op.cc                            | 18 +++++------
 src/relax/transform/fold_constant.cc          | 24 ++++++++++++--
 .../relax/test_transform_fold_constant.py     | 32 +++++++++++++++++++
 4 files changed, 64 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index b363dc6952d8..7ce190667d11 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -25,7 +25,7 @@
 from . import _ffi_api
 from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
 from ..expr import Tuple as RxTuple
-from ..struct_info import StructInfo, TensorStructInfo
+from ..struct_info import StructInfo, TensorStructInfo, TupleStructInfo
 from ...ir import PrimExpr
 from ..utils import args_converter
 
@@ -97,7 +97,7 @@ def call_tir(
     ret: Call
         A call node for the call_tir operator.
""" - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + if isinstance(args, Expr) and not isinstance(args, RxTuple) and not isinstance(args.struct_info_, TupleStructInfo): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 01d0d04be0cc..40365a533c75 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -263,7 +263,7 @@ RELAY_REGISTER_OP("relax.call_tir") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, +Expr MakeCallTIR(Expr func, Expr arg_tuple, Array out_sinfo_list, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); @@ -283,9 +283,9 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, {}, {out_sinfo}); + call = Call(op, {func, arg_tuple}, {}, {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, {}, {out_sinfo}); } return call; } @@ -307,7 +307,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, +Expr MakeCallTIRWithGrad(Expr func, Expr arg_tuple, Array out_sinfo_list, String te_grad_name, Map te_grad_kwargs, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { @@ -333,9 +333,9 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo}); } return call; } @@ -453,7 +453,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") // arguments will no longer be live) .set_attr("FPurity", Bool(true)); -Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, +Expr MakeCallTIRInplace(Expr func, Expr arg_tuple, Array inplace_indices, Array out_sinfo_list, Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); @@ -476,9 +476,9 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, Call call; if (!packed_ints) { // don't use additional optional argument - call = Call(op, {func, args}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple}, Attrs(attrs), {out_sinfo}); } else { - call = Call(op, {func, args, packed_ints.value()}, Attrs(attrs), {out_sinfo}); + call = Call(op, {func, arg_tuple, packed_ints.value()}, Attrs(attrs), {out_sinfo}); } return call; } diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index d6da79c484cf..963fb601829b 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -177,9 +177,27 @@ class ConstantFolder : public ExprMutator { // call_tir needs to have at least three arguments ICHECK_GE(call->args.size(), 2); Optional func = MatchPrimFunc(call->args[0]); - ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; - Optional> arr_args = - 
MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
+
+    Expr arg_expr = call->args[1];
+    ICHECK(GetStructInfo(arg_expr).as<TupleStructInfoNode>())
+        << "call_tir.args[1] must be Tuple"
+        << ", but instead args[1] was " << arg_expr << " with struct info "
+        << GetStructInfo(arg_expr);
+
+    // The arguments to call_tir may be a variable bound to a tuple,
+    // rather than the tuple itself.  Unwrap any such variable
+    // bindings if they are known.
+    arg_expr = UnwrapBindings(arg_expr);
+
+    // The arguments may still be a variable bound to a tuple, where
+    // the tuple was produced outside the current function.  In that
+    // case, we cannot fold the expression.
+    auto arg_tuple = arg_expr.as<TupleNode>();
+    if (!arg_tuple) {
+      return NullOpt;
+    }
+
+    Optional<Array<runtime::NDArray>> arr_args = MatchConstArrayArgs(arg_tuple->fields);
     ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";
     Optional<Array<IntImm>> shape = MatchConstShape(call->sinfo_args[0]);
     bool output_not_tuple = call->sinfo_args.size() == 1;
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index 9f2e3a4a092d..0cacbab28d04 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -84,6 +84,38 @@ def expected(c1: R.Tensor((16, 16), "float32")):
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_one_fold_addone_with_arg_tuple():
+    """Like test_one_fold_addone, but with the argument tuple bound to a variable"""
+
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @R.function
+        def before(c0: R.Tensor((16, 16), "float32")):
+            cls = Module
+            arg_tuple = (c0,)
+            lv0 = relax.call_tir(cls.addone, arg_tuple, R.Tensor((16, 16), dtype="float32"))
+            return lv0
+
+        @R.function
+        def expected(c1: R.Tensor((16, 16), "float32")):
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 def test_one_fold_transpose():
     # put before after in a single module
     @tvm.script.ir_module

From d4f27f2b60cfa2735dfef506dfda3f6673c99258 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 3 Oct 2023 19:27:24 +0000
Subject: [PATCH 04/10] [Unity] Handle tuple variables in
 relax.transform.FuseOps

Prior to this commit, the `relax.transform.FuseOps` pass assumed that
the "relax.call_tir" builtin had arguments expressed as `relax.Tuple`,
and failed if provided with a `relax.Var` that had been bound to a
tuple.

This commit updates the `FuseOps` pass to unwrap variable bindings
prior to downcasting from `relax::Expr` to `relax::Tuple`.
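
As a minimal sketch of the pattern being handled (the module and
PrimFunc names below are illustrative, not taken from the test
suite), the arguments of call_tir may be a Var bound to a tuple
rather than an in-line relax.Tuple:

    @R.function
    def main(x: R.Tensor((16, 16), "float32")):
        cls = Module
        # `args` is a Var bound to a tuple, not an in-line relax.Tuple
        args = (x,)
        lv0 = relax.call_tir(cls.some_tir_func, args, R.Tensor((16, 16), dtype="float32"))
        return lv0

Previously, FuseOps attempted `Downcast<Tuple>(call->args[1])` on
`args` and failed; it now follows the binding back to the tuple
literal where possible, and otherwise passes the tuple object through
unmodified.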
--- src/relax/transform/fuse_ops.cc | 176 ++++++++++++------ tests/python/relax/test_transform_fuse_ops.py | 42 +++++ 2 files changed, 157 insertions(+), 61 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 5cabfc40ca9c..6d13a9b3f793 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -150,12 +150,16 @@ class GraphCreator : public ExprVisitor { } void VisitBinding_(const MatchCastNode* binding) final { + ExprVisitor::VisitBinding_(binding); + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); SetNodePattern(node, OpPatternKind::kOpaque); AddToPostDFSOrder(node, binding->var.get()); } void VisitBinding_(const VarBindingNode* binding) final { + ExprVisitor::VisitBinding_(binding); + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); // If the variable is not a dataflow variable, it must be the output variable of this dataflow @@ -169,9 +173,12 @@ class GraphCreator : public ExprVisitor { } else if (const auto* tuple_get_item = binding->value.as()) { // Case 2. The expression is a TupleGetItemNode VisitTupleGetItem(tuple_get_item, node); + } else if (const auto* tuple = binding->value.as()) { + // Case 3. The expression is a Tuple + VisitTuple(tuple, node); } else { VisitUnsupportedNode(binding->value, node); - // Case 3. The type of the expression is not fusion-supported. + // Case 4. The type of the expression is not fusion-supported. // In this case, we skip adding edges, adding an empty node into graph. } AddToPostDFSOrder(node, binding->var.get()); @@ -195,8 +202,16 @@ class GraphCreator : public ExprVisitor { const GlobalVar& global_var = Downcast(call->args[0]); tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); - // Override args for call_tir - args = Downcast(call->args[1])->fields; + // Override args for call_tir, where possible. If the argument + // is a variable bound to a tuple, do not unwrap it, as doing so + // would prevent the tuple object from being exposed to the + // GraphPartitioner. + Expr arg_tuple = call->args[1]; + if (auto ptr = arg_tuple.as()) { + args = ptr->fields; + } else { + args = {arg_tuple}; + } Optional opt_pattern = func->GetAttr("op_pattern"); if (opt_pattern.defined()) { @@ -217,6 +232,13 @@ class GraphCreator : public ExprVisitor { } } + void VisitTuple(const TupleNode* tuple, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kTuple); + VisitLeaf(GetRef(tuple), binding_var_node, OpPatternKind::kTuple); + } + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, IndexedForwardGraph::Node* binding_var_node) { ICHECK_NOTNULL(binding_var_node); @@ -248,17 +270,11 @@ class GraphCreator : public ExprVisitor { for (const Expr& expr : tuple->fields) { VisitLeaf(expr, binding_var_node, pattern); } - return; - } - - if (!leaf_expr->IsInstance()) { - // Skip GlobalVar, ExternFunc, OpNode. - return; } - auto it = graph_.node_map.find(leaf_expr.get()); IndexedForwardGraph::Node* leaf_node = nullptr; - if (it != graph_.node_map.end()) { + + if (auto it = graph_.node_map.find(leaf_expr.get()); it != graph_.node_map.end()) { leaf_node = it->second; } else if (leaf_expr->IsInstance() || leaf_expr->IsInstance() || leaf_expr->IsInstance() || leaf_expr->IsInstance() || @@ -267,11 +283,11 @@ class GraphCreator : public ExprVisitor { // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. 
SetNodePattern(leaf_node, OpPatternKind::kOpaque); AddToPostDFSOrder(leaf_node, leaf_expr.get()); - } else { - LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr - << " used before definition."; } - AddEdge(leaf_node, binding_var_node, pattern); + + if (leaf_node) { + AddEdge(leaf_node, binding_var_node, pattern); + } } /********** Helper Functions **********/ @@ -387,14 +403,21 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { + Array args; + if (call->op == Op::Get("relax.call_tir")) { // Update the name of the function. name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; - const Tuple& args = Downcast(call->args[1]); - for (const Expr& arg : args->fields) { - CheckDefAndUpdateParam(arg); - ICHECK(GetStructInfoAs(arg) == nullptr); + // Unwrap a tuple variable to a tuple expression, if + // possible. This treats use of a relax::Var that points to + // a tuple to be treated as a direct usage of the underlying + // relax::Var, preventing duplicate parameters. + Expr arg_tuple = UnwrapBindings(call->args[1]); + if (auto ptr = arg_tuple.as()) { + args = ptr->fields; + } else { + args = {arg_tuple}; } // TODO(tvm-team): handle shape expr } else { @@ -409,27 +432,29 @@ class FunctionCreator : public ExprMutator { } } - for (const Expr& arg : call->args) { - CheckDefAndUpdateParam(arg); - if (GetStructInfoAs(arg) != nullptr) { - // The argument is fully referenced. Thus we remove it from the mapping. - partially_used_tuple_params_.erase(arg.get()); - } + args = call->args; + } + + for (const Expr& arg : args) { + CheckDefAndUpdateParam(arg); + + if (arg->struct_info_.as()) { + // Mark the tuple as being used as a tuple object. Therefore, we'll retain it as a + // tuple parameter. + tuple_param_info_[arg.get()].is_fully_used = true; } } - } else if (var_binding->value.as()) { - const auto* tuple_item = var_binding->value.as(); + + } else if (const auto* tuple_item = var_binding->value.as()) { CheckDefAndUpdateParam(tuple_item->tuple); - if (partially_used_tuple_params_.find(tuple_item->tuple.get()) != - partially_used_tuple_params_.end()) { - // Appending get-item index to the mapping. - partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index); - } + tuple_param_info_[tuple_item->tuple.get()].used_by_index.push_back(tuple_item->index); + } else if (const auto* tuple = var_binding->value.as()) { + tuple_param_info_[tuple].is_fully_used = true; } // Mark the binding variable as defined. - defined_vars_.insert(var_binding->var.get()); + defined_vars_.insert({var_binding->var.get(), var_binding->value}); // Set var as output true if the binding is not a dataflow variable if (!var_binding->var->IsInstance()) { AppendOutput(var_binding->var); @@ -465,27 +490,31 @@ class FunctionCreator : public ExprMutator { // parameters with the parameters of its fields that are accessed in the // function. 
std::unordered_map> tuple_get_item_remap; - for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) { - ICHECK(!item_indices.empty()); - int param_idx = tuple_param_idx_[tuple_arg]; - Var param = params_[param_idx]; - String param_name = params_[param_idx]->name_hint(); - TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); - - Array item_args; - Array item_params; - item_args.reserve(item_indices.size()); - item_params.reserve(item_indices.size()); - for (int item_idx : item_indices) { - Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]); - item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); - item_params.push_back(item_param); - tuple_get_item_remap[tuple_arg][item_idx] = item_param; + for (const auto& [tuple_arg, tuple_info] : tuple_param_info_) { + if (!tuple_info.is_fully_used && tuple_info.param_index.has_value()) { + const auto& item_indices = tuple_info.used_by_index; + ICHECK(!item_indices.empty()); + int param_idx = tuple_info.param_index.value(); + Var param = params_[param_idx]; + String param_name = params_[param_idx]->name_hint(); + TupleStructInfo param_sinfo = Downcast(tuple_arg->struct_info_); + + Array item_args; + Array item_params; + item_args.reserve(item_indices.size()); + item_params.reserve(item_indices.size()); + for (int item_idx : item_indices) { + Var item_param(param_name + "_" + std::to_string(item_idx), + param_sinfo->fields[item_idx]); + item_args.push_back(TupleGetItem(GetRef(tuple_arg), item_idx)); + item_params.push_back(item_param); + tuple_get_item_remap[tuple_arg][item_idx] = item_param; + } + arguments_.erase(arguments_.begin() + param_idx); + arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end()); + params_.erase(params_.begin() + param_idx); + params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end()); } - arguments_.erase(arguments_.begin() + param_idx); - arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end()); - params_.erase(params_.begin() + param_idx); - params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end()); } // Step 3. Visit each binding and collect outputs one by one. @@ -606,8 +635,7 @@ class FunctionCreator : public ExprMutator { // Mark the tuple parameter is partially referenced in the beginning. // We will remove it from the mapping once we find it is fully referenced. if (param_sinfo->IsInstance()) { - partially_used_tuple_params_[expr.get()] = {}; - tuple_param_idx_[expr.get()] = static_cast(arguments_.size()) - 1; + tuple_param_info_[expr.get()].param_index = static_cast(arguments_.size()) - 1; } } } @@ -623,21 +651,47 @@ class FunctionCreator : public ExprMutator { } private: + /* \brief Shadow the ExprMutator::UnwrapBindings + * + * Because the ExprMutator only knows of bindings that have been + * visited, and we do not call VisitBinding until FunctionCreate, we + * cannot use it to unwrap the known binding. Therefore, + * reproducing the same functionality here. + */ + Expr UnwrapBindings(Expr expr) { + while (true) { + auto var = expr.as(); + if (!var) return expr; + + auto it = defined_vars_.find(var.get()); + if (it == defined_vars_.end()) return expr; + + expr = it->second; + } + } + /*! \brief The variables defined in this function */ - std::unordered_set defined_vars_; + std::unordered_map defined_vars_; /*! \brief The number of parameters reserved for constants */ int n_param_for_const_ = 0; /*! 
\brief The output vars */
   std::vector<Var> output_vars_;
   /*! \brief Whether or not to lift bound constants to parameters */
   bool lift_constant_;
-  /*! \brief Mapping from tuple parameter of the function to its position index */
-  std::unordered_map<const ExprNode*, int> tuple_param_idx_;
+
+  struct TupleParamInfo {
+    std::optional<int> param_index = std::nullopt;
+    bool is_fully_used{false};
+    std::vector<int> used_by_index;
+  };
+
   /*!
-   * \brief Mapping from partially referenced tuple parameter to the list of
-   *        indices that the parameter is referred by TupleGetItem
+   * \brief Mapping from tuple parameters to collected information about them.
+   *
+   * Used to decide whether to pass individual tuple elements to the
+   * fused function, or entire tuple objects.
    */
-  std::unordered_map<const ExprNode*, std::vector<int>> partially_used_tuple_params_;
+  std::unordered_map<const ExprNode*, TupleParamInfo> tuple_param_info_;
 };
 
 /*!
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 1a4a630e3e5a..c60312deda3b 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1501,5 +1501,47 @@ def main(
     _check(Module, Expected)
 
 
+def test_fuse_call_tir_with_var_tuple():
+    """Call_tir may contain either an inline tuple or a var"""
+
+    def before():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.sin, x)
+                lv1 = bb.emit_te(topi.cos, x)
+                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+
+        with bb.function("fused_sin_cos_add", [x], attrs={"Primitive": 1}, private=True):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.sin, x)
+                lv1 = bb.emit_te(topi.cos, x)
+                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1))
+            bb.emit_func_output(gv)
+        fused_sin_cos_add = bb.get().get_global_var("fused_sin_cos_add")
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(relax.Call(fused_sin_cos_add, [x]))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    before = relax.transform.EliminateCommonSubexpr()(before())
+    expected = relax.transform.EliminateCommonSubexpr()(expected())
+
+    _check(before, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 36bea9b4dc18e195acdc3d0868bcc782d882db9b Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 3 Oct 2023 19:31:03 +0000
Subject: [PATCH 05/10] [Unity] Handle tuple variables in
 RewriteDataflowReshape

Prior to this commit, the `relax.transform.RewriteDataflowReshape`
pass assumed that the "relax.call_tir" builtin had arguments expressed
as `relax.Tuple`, and failed if provided with a `relax.Var` that had
been bound to a tuple.

This commit updates the `RewriteDataflowReshape` pass to handle
variables annotated with `TupleStructInfo`.  A reshape can be
identified from the number of arguments alone, which can be extracted
from the variable's `TupleStructInfo` instead of requiring a `Tuple`.
If the TIR function is a reshape, then the tuple variable can be
unwrapped to a known tuple in order to find the argument, or a
`TupleGetItem` node can be used to extract the argument.
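
For example, given a binding such as the following (taken from the
unit test added below):

    reshape_args = (x,)
    y = R.call_tir(cls.reshape, reshape_args, out_sinfo=R.Tensor((1,), dtype="float32"))

the pass unwraps `reshape_args` back to the tuple `(x,)` and rewrites
the call to:

    y = R.reshape(x, R.shape([1]))

If the binding is not visible (e.g. the tuple was produced by another
function), a `TupleGetItem` node is emitted to extract the argument
instead of failing.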
--- .../transform/rewrite_dataflow_reshape.cc | 30 ++++++-- ...test_transform_rewrite_dataflow_reshape.py | 73 ++++++++++++++++++- 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 8345f3e0b745..6c96e2b99560 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -82,8 +82,11 @@ class DataflowReshapeRewriter : public ExprMutator { // vm.builtin.reshape in the VMBuiltinLower pass. auto prim_fn = Downcast(mod_->Lookup(Downcast(call->args[0]))); - auto arg_tuple = Downcast(call->args[1])->fields; - auto used_arg_indices = GetUsedArgsIndices(prim_fn, arg_tuple.size()); + + auto args_expr = call->args[1]; + auto args_sinfo = GetStructInfo(args_expr).as(); + + auto used_arg_indices = GetUsedArgsIndices(prim_fn, args_sinfo->fields.size()); // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR @@ -92,18 +95,32 @@ class DataflowReshapeRewriter : public ExprMutator { if (used_arg_indices.size() != 1) { return GetRef(call); } + size_t arg_index = used_arg_indices[0]; - auto arg = arg_tuple[used_arg_indices[0]]; + auto arg_sinfo = Downcast(args_sinfo->fields[arg_index]); - if (!IsCallingTIRReshape(call, arg)) { + if (!IsCallingTIRReshape(call, arg_sinfo)) { return GetRef(call); } + // Now we know that we're calling a reshape, but we don't yet know + // on what. Ideally, the arguments are either a tuple, or a + // variable that is bound to a known tuple, but we may need to + // fall back to a TupleGetItem. + args_expr = UnwrapBindings(args_expr); + auto arg = [&]() -> Expr { + if (auto known_tuple = args_expr.as()) { + return known_tuple->fields[arg_index]; + } else { + return TupleGetItem(args_expr, arg_index); + } + }(); + TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); return reshape(arg, res_sinfo->shape.value()); } - bool IsCallingTIRReshape(const CallNode* call, Expr inp) { + bool IsCallingTIRReshape(const CallNode* call, TensorStructInfo inp_sinfo) { const GlobalVar& global_var = Downcast(call->args[0]); const auto* func = mod_->functions.Get(global_var).as(); ICHECK_NOTNULL(func); @@ -115,8 +132,7 @@ class DataflowReshapeRewriter : public ExprMutator { // as the number of elements in the result. There are operators that could have a reshape // pattern that don't meet this requirement (e.g. strided_slice), and they should not be // converted to reshape. 
- ICHECK(inp->struct_info_.defined() && call->struct_info_.defined()); - TensorStructInfo inp_sinfo = Downcast(inp->struct_info_.value()); + ICHECK(call->struct_info_.defined()); TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) { diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index 01065bea21df..4680e51a6888 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax -from tvm.script import relax as R, tir as T +from tvm.script import relax as R, tir as T, ir as I def test_reshape_expand_dims(): @@ -500,5 +500,76 @@ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): tvm.ir.assert_structural_equal(rewritten, Expected) +def test_reshape_with_tuple_var_as_argument(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(1),), "float32"), + B: T.Buffer((T.int64(1),), "float32"), + T_add: T.Buffer((T.int64(1),), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_add"): + vi = T.axis.spatial(T.int64(1), i) + T_add[vi] = A[vi] + B[vi] + + @T.prim_func(private=True) + def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_reshape"): + vi = T.axis.spatial(T.int64(1), i) + T_reshape[vi] = A[()] + + @R.function + def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): + cls = Before + with R.dataflow(): + reshape_args = (x,) + y = R.call_tir(cls.reshape, reshape_args, out_sinfo=R.Tensor((1,), dtype="float32")) + add_args = (y, y) + z = R.call_tir(cls.add, add_args, out_sinfo=R.Tensor((1,), dtype="float32")) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(1),), "float32"), + B: T.Buffer((T.int64(1),), "float32"), + T_add: T.Buffer((T.int64(1),), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_add"): + vi = T.axis.spatial(T.int64(1), i) + T_add[vi] = A[vi] + B[vi] + + @T.prim_func(private=True) + def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(T.int64(1)): + with T.block("T_reshape"): + vi = T.axis.spatial(T.int64(1), i) + T_reshape[vi] = A[()] + + @R.function + def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), dtype="float32"): + cls = Expected + with R.dataflow(): + reshape_args = (x,) + y = R.reshape(x, R.shape([1])) + add_args = (y, y) + z = R.call_tir(cls.add, add_args, out_sinfo=R.Tensor((1,), dtype="float32")) + R.output(z) + return z + + After = relax.transform.RewriteDataflowReshape()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() From 7a91d13deb2291f387e33430053e8ad995413872 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 4 Oct 2023 20:51:03 +0000 Subject: [PATCH 06/10] [Unity] Handle tuple vars in CallTIRRewrite --- src/relax/transform/call_tir_rewrite.cc | 84 ++++++++++++++++--------- tests/python/relax/test_transform.py | 77 ++++++++++++++++------- 2 files changed, 
110 insertions(+), 51 deletions(-) diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index e040ccea1485..7553ec9f97e2 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -57,7 +57,7 @@ class CallTIRMutator : public ExprMutator { if (call->op == call_tir_op || call->op == call_tir_inplace_op || call->op == call_dps_packed_op) { - bool is_inplace = (call->op == call_tir_inplace_op); + bool is_inplace_op = (call->op == call_tir_inplace_op); const auto* inplace_attrs = call->attrs.as(); Array outs; if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { @@ -65,7 +65,7 @@ class CallTIRMutator : public ExprMutator { const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); ICHECK(tensor_sinfo->shape.defined()) << "the TensorStructInfo shape of call_tir has not populated"; - if (!is_inplace) { + if (!is_inplace_op) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, // {Downcast(tensor_sinfo->shape.value()), @@ -93,7 +93,7 @@ class CallTIRMutator : public ExprMutator { ICHECK(field_tensor->shape.defined()) << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor << " as an element of TupleStructInfo"; - if (!is_inplace || inplace_attrs->inplace_indices[i].IntValue() == -1) { + if (!is_inplace_op || inplace_attrs->inplace_indices[i].IntValue() == -1) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {Downcast(field_tensor->shape.value()), @@ -111,41 +111,69 @@ class CallTIRMutator : public ExprMutator { << expr->struct_info_; } - Array args; - if (call->args[1].as()) { - args = Downcast(call->args[1])->fields; - // for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones - if (!is_inplace) { - args.insert(args.end(), outs.begin(), outs.end()); - } else { - for (size_t i = 0; i < outs.size(); i++) { - if (inplace_attrs->inplace_indices[i].IntValue() == -1) { - args.push_back(outs[i]); - } + Expr callee = call->args[0]; + Expr arg_tuple = call->args[1]; + Optional shape_tuple_of_tir_args = NullOpt; + if (call->args.size() > 2) { + shape_tuple_of_tir_args = call->args[2]; + } + + while (true) { + auto as_var = arg_tuple.as(); + if (!as_var) break; + + auto bound_expr = LookupBinding(as_var.value()); + if (!bound_expr) break; + + arg_tuple = bound_expr.value(); + } + + Array args = [&]() { + if (auto ptr = arg_tuple.as()) { + return ptr->fields; + } else if (auto ptr = arg_tuple->struct_info_.as()) { + size_t n_args = ptr->fields.size(); + Array args; + for (size_t i = 0; i < n_args; i++) { + args.push_back(TupleGetItem(arg_tuple, i)); } + return args; + } else { + LOG(FATAL) << "Lowering of " << call + << " requires knowing how many arguments are passed to the function. 
" + << "However, the tuple of arguments " << arg_tuple + << " is not itself a tuple, " + << "nor does its struct info " << GetStructInfo(arg_tuple) + << " define the number of arguments."; } + }(); - if (call->args.size() == 2) { - builder_->Emit(Call(call->args[0], args), "_"); - } else { - // unpack semantics - args.push_back(call->args[2]); - builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); + for (size_t i = 0; i < outs.size(); i++) { + bool output_is_inplace = + is_inplace_op && inplace_attrs->inplace_indices[i].IntValue() != -1; + if (!output_is_inplace) { + args.push_back(outs[i]); } - } else { - if (!is_inplace) { - args = outs; - args.insert(args.begin(), call->args[1]); + } + + if (shape_tuple_of_tir_args) { + args.push_back(shape_tuple_of_tir_args.value()); + } + + Expr new_call = [&]() { + if (shape_tuple_of_tir_args) { + return Call(call_tir_dyn_op, {callee, Tuple(args)}); } else { - args.push_back(call->args[1]); + return Call(callee, args); } - builder_->Emit(Call(call->args[0], args), "_"); - } + }(); + builder_->Emit(new_call, "_"); if (outs.size() == 1) { return outs[0]; + } else { + return Tuple(outs); } - return std::move(Tuple(outs)); } return GetRef(call); diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 9ab2ffc60536..ee48bb7154c3 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -21,7 +21,7 @@ from tvm.ir import structural_equal import tvm.script -from tvm.script import tir as T, relax as R +from tvm.script import tir as T, relax as R, ir as I def test_to_non_dataflow(): @@ -84,8 +84,8 @@ def fvisit(e): def test_call_tir_rewrite(): - @tvm.script.ir_module - class TestCallTIRRewrite: + @I.ir_module + class Before: @T.prim_func def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): T.evaluate(0) @@ -95,32 +95,63 @@ def foo(x: R.Tensor(("m", "n"), "float32")): # we expect RemovePurityChecking to have been used before this point R.func_attr({"relax.force_pure": True}) m, n = T.int64(), T.int64() - gv0 = R.call_tir(TestCallTIRRewrite.exp, (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir(Before.exp, (x,), R.Tensor((m, n), dtype="float32")) return gv0 - mod = TestCallTIRRewrite + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) - # before rewrite - v0 = mod["foo"].body.blocks[0].bindings[0].var - s0 = mod["foo"].body.blocks[0].bindings[0].value - assert isinstance(s0, relax.Call) - assert s0.op.name == "relax.call_tir" + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + alloc = R.builtin.alloc_tensor(R.shape([m, n]), "float32", R.prim_value(0)) + _ = Expected.exp(x, alloc) + gv0 = alloc + return gv0 - # after rewrite - new_mod = relax.transform.CallTIRRewrite()(mod) - func = new_mod["foo"] + After = relax.transform.CallTIRRewrite()(Before) + tvm.ir.assert_structural_equal(Expected, After) - block = func.body.blocks[0] - assert not isinstance(block, relax.DataflowBlock) - s1 = block.bindings[0].value - assert isinstance(s1, relax.Call) - assert s1.op.name == "relax.builtin.alloc_tensor" - assert isinstance(s1.args[0], relax.ShapeExpr) - assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) - s2 = block.bindings[1].value - tvm.ir.expr.GlobalVar - assert s2.op.name_hint == "exp" +def 
test_call_tir_rewrite_with_var_tuple(): + @I.ir_module + class Before: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + exp_args = (x,) + gv0 = R.call_tir(Before.exp, exp_args, R.Tensor((m, n), dtype="float32")) + return gv0 + + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + R.func_attr({"relax.force_pure": True}) + m, n = T.int64(), T.int64() + exp_args = (x,) + alloc = R.builtin.alloc_tensor(R.shape([m, n]), "float32", R.prim_value(0)) + _ = Expected.exp(x, alloc) + gv0 = alloc + return gv0 + + After = relax.transform.CallTIRRewrite()(Before) + tvm.ir.assert_structural_equal(Expected, After) def test_transform_remove_purity_checking(): From a48c9117485005dce22793b76655773c01ee14eb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 12 Oct 2023 15:02:02 +0000 Subject: [PATCH 07/10] lint fix --- python/tvm/relax/op/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 7ce190667d11..1d53b58ac1b5 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -97,7 +97,11 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple) and not isinstance(args.struct_info_, TupleStructInfo): # type: ignore + if ( + isinstance(args, Expr) + and not isinstance(args, RxTuple) + and not isinstance(args.struct_info_, TupleStructInfo) + ): args = RxTuple((args,)) if not isinstance(out_sinfo, list): From fc16d5da7de6e48a1d1f281c22ad3187fa2775e5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 08:14:06 -0500 Subject: [PATCH 08/10] Check for TupleStructInfo in all call_* variants --- python/tvm/relax/op/base.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 1d53b58ac1b5..c62b0221a476 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -66,6 +66,18 @@ def null_value() -> Call: return _ffi_api.null_value() # type: ignore +def _normalize_arg_tuple(args: Expr) -> Expr: + if isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo): + # A tuple, or a Var bound to a tuple, are kept as-is + return args + elif isinstance(args, Expr): + # A single argument is wrapped into a tuple + return RxTuple((args,)) + else: + # Anything else is left for the FFI to handle + return args + + @args_converter.auto def call_tir( gvar: GlobalVar, @@ -97,12 +109,7 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if ( - isinstance(args, Expr) - and not isinstance(args, RxTuple) - and not isinstance(args.struct_info_, TupleStructInfo) - ): - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -156,8 +163,7 @@ def call_tir_with_grad( ret: Call A call node for the call_tir_with_grad operator. 
""" - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] @@ -224,8 +230,7 @@ def call_tir_inplace( ret: Call A call node for the call_tir operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(inplace_indices, list): inplace_indices = [inplace_indices] @@ -279,8 +284,7 @@ def call_dps_packed( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore - args = RxTuple((args,)) + args = _normalize_arg_tuple(args) if not isinstance(out_sinfo, list): out_sinfo = [out_sinfo] From 4108fb0a37e0079574db2c40cf7aebacce059f96 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 10:40:56 -0500 Subject: [PATCH 09/10] Pull in improvements and unit tests from PR#15971 --- src/relax/op/op.cc | 21 +++++----- src/relax/transform/call_tir_rewrite.cc | 21 +++++++--- tests/python/relax/test_transform.py | 56 +++++++++++++++++++++++++ tests/python/relax/test_vm_build.py | 49 +++++++++++++++++++++- 4 files changed, 129 insertions(+), 18 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 40365a533c75..4429347893d2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -364,7 +364,8 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c } // check the range for inplace indices, make sure at least one is not -1, ensure they're unique - size_t num_args = Downcast(call->args[1])->fields.size(); + auto arg_sinfo = GetStructInfoAs(call->args[1]); + size_t num_args = arg_sinfo->fields.size(); std::unordered_set encountered; for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { int index = attrs->inplace_indices[i].IntValue(); @@ -391,14 +392,13 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c // for safety, we will make sure the output shape for each in-place argument exactly matches the // input shape // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true - Tuple call_args = Downcast(call->args[1]); if (attrs->inplace_indices.size() == 1) { auto* out_sinfo = call->sinfo_args[0].as(); if (!out_sinfo) { ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[0].IntValue()]); + auto* input_sinfo = + arg_sinfo->fields[attrs->inplace_indices[0].IntValue()].as(); if (!input_sinfo || !input_sinfo->shape.defined() || !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), ctx->GetAnalyzer())) { @@ -412,24 +412,23 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c } else { auto out_sinfos = call->sinfo_args[0].as()->fields; for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { - if (attrs->inplace_indices[i].IntValue() == -1) { + int inplace_index = attrs->inplace_indices[i].IntValue(); + if (inplace_index == -1) { continue; } auto* out_sinfo = out_sinfos[i].as(); if (!out_sinfo) { ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor"); } - auto* input_sinfo = GetStructInfoAs( - call_args->fields[attrs->inplace_indices[i].IntValue()]); + auto* input_sinfo = arg_sinfo->fields[inplace_index].as(); if (!input_sinfo || !input_sinfo->shape.defined() || 
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(), ctx->GetAnalyzer())) { ctx->ReportFatal(Diagnostic::Error(call) << "The shape of output " << i << " must match that of input " - << attrs->inplace_indices[i].IntValue() << ", whereas we have " - << out_sinfo->shape.value() << " in output " << i << " versus " - << input_sinfo->shape.value() << " in input " - << attrs->inplace_indices[i].IntValue()); + << inplace_index << ", whereas we have " << out_sinfo->shape.value() + << " in output " << i << " versus " << input_sinfo->shape.value() + << " in input " << inplace_index); } } } diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 7553ec9f97e2..faabf127485d 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -77,14 +77,14 @@ class CallTIRMutator : public ExprMutator { ICHECK(inplace_attrs->inplace_indices[0].IntValue() != -1) << "If calling call_tir_inplace and there is one output, its in-place index must not" " be -1."; - outs.push_back( - Downcast(call->args[1])->fields[inplace_attrs->inplace_indices[0].IntValue()]); + outs.push_back(GetTupleIndex(call->args[1], 0)); } } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { // multiple output case const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { const auto& field = tuple_sinfo->fields[i]; + int inplace_index = inplace_attrs->inplace_indices[i].IntValue(); ICHECK(field->IsInstance()) << "call_tir expects Tuple of TensorStructInfo, but got " << field @@ -93,7 +93,7 @@ class CallTIRMutator : public ExprMutator { ICHECK(field_tensor->shape.defined()) << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor << " as an element of TupleStructInfo"; - if (!is_inplace_op || inplace_attrs->inplace_indices[i].IntValue() == -1) { + if (!is_inplace_op || inplace_index == -1) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {Downcast(field_tensor->shape.value()), @@ -101,8 +101,7 @@ class CallTIRMutator : public ExprMutator { Attrs()), "alloc")); } else { - outs.push_back(Downcast(call->args[1]) - ->fields[inplace_attrs->inplace_indices[i].IntValue()]); + outs.push_back(GetTupleIndex(call->args[1], inplace_index)); } } } else { @@ -178,6 +177,18 @@ class CallTIRMutator : public ExprMutator { return GetRef(call); } + + private: + // If e is a tuple literal, return the field denoted by the index. + // Otherwise, insert a tuple get item for that field and return the + // var the result is bound to. 
+ Expr GetTupleIndex(const Expr& e, int index) { + if (const auto* tuple_node = e.as()) { + return tuple_node->fields[index]; + } + auto out = builder_->Emit(TupleGetItem(e, index)); + return out; + } }; Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index ee48bb7154c3..c35d82d1a4c3 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -154,6 +154,62 @@ def foo(x: R.Tensor(("m", "n"), "float32")): tvm.ir.assert_structural_equal(Expected, After) +def test_call_tir_rewrite_separate_tuple(): + # if the arguments to call_tir are tuple-typed but not a tuple literal, + # the rewrite should index into the tuple + @tvm.script.ir_module + class TestCallTIRRewrite: + @T.prim_func + def add( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + C: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def foo(x: R.Tensor((2, 3), "float32")): + # we expect RemovePurityChecking to have been used before this point + R.func_attr({"relax.force_pure": True}) + tup = (x, x) + gv0 = R.call_tir(TestCallTIRRewrite.add, tup, R.Tensor((2, 3), dtype="float32")) + return gv0 + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + C: T.Buffer((2, 3), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def foo(x: R.Tensor((2, 3), "float32")): + R.func_attr({"relax.force_pure": True}) + tup = (x, x) + alloc = R.builtin.alloc_tensor(R.shape([2, 3]), dtype="float32", runtime_device_index=0) + _ = Expected.add(x, x, alloc) + gv0 = alloc + return gv0 + + new_mod = relax.transform.CallTIRRewrite()(TestCallTIRRewrite) + tvm.ir.assert_structural_equal(new_mod, Expected) + + def test_transform_remove_purity_checking(): @tvm.script.ir_module class Before: diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 4f28c4a47a69..8dfc63878a15 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -231,9 +231,11 @@ def main( ) -> R.Tuple( R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") ): + # also make sure it works with a tuple bound separately + tup = (x, y, z) res = R.call_tir_inplace( TestCallTIRInplaceE2ESimple.copy, - (x, y, z), + tup, [0, 1, -1], [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")], ) @@ -303,6 +305,48 @@ def main( tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_call_tir_reuse_tuple_input(exec_mode): + # read and write from the same tensor + @tvm.script.ir_module + class TestCallTIRTupleInput: + @T.prim_func + def add( + A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = 
T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")): + tup = (x, y) + res1 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + res2 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + return (res1, res2) + + mod = TestCallTIRTupleInput + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + x = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + y = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + vm.set_input("main", x, y) + vm.invoke_stateful("main") + out = vm.get_outputs("main") + expected = tvm.nd.array(np.full((2, 3), 2).astype(np.int32)) + tvm.testing.assert_allclose(out[0].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(out[1].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + + def test_vm_emit_te_extern(exec_mode): if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): print("skip because extern function is not available") @@ -763,7 +807,8 @@ def relax_matmul_tir( ) -> R.Tensor((32, 32), dtype="float32"): cls = TestVMSubFunction with R.dataflow(): - gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + tup = (x, w) + gv0 = R.call_tir(cls.tir_matmul, tup, R.Tensor((32, 32), dtype="float32")) R.output(gv0) return gv0 From 5089f47f8f8f5ac6e109120fb344ff8160101fbb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 08:51:59 -0500 Subject: [PATCH 10/10] Fix breakage caused by incorrect merging with PR#15971 changes --- src/relax/transform/call_tir_rewrite.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index faabf127485d..3b3e97517070 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -84,7 +84,6 @@ class CallTIRMutator : public ExprMutator { const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { const auto& field = tuple_sinfo->fields[i]; - int inplace_index = inplace_attrs->inplace_indices[i].IntValue(); ICHECK(field->IsInstance()) << "call_tir expects Tuple of TensorStructInfo, but got " << field @@ -93,7 +92,7 @@ class CallTIRMutator : public ExprMutator { ICHECK(field_tensor->shape.defined()) << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor << " as an element of TupleStructInfo"; - if (!is_inplace_op || inplace_index == -1) { + if (!is_inplace_op || inplace_attrs->inplace_indices[i].IntValue() == -1) { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {Downcast(field_tensor->shape.value()), @@ -101,7 +100,8 @@ class CallTIRMutator : public ExprMutator { Attrs()), "alloc")); } else { - outs.push_back(GetTupleIndex(call->args[1], inplace_index)); + outs.push_back( + GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[i].IntValue())); } } } else {
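
As an end-to-end usage sketch (hypothetical driver code, mirroring
the unit tests in this series; `Before` stands for any module in
which call_tir receives a Var bound to a tuple, as in
test_call_tir_rewrite_with_var_tuple above):

    import tvm
    from tvm import relax

    # Lower a module whose call_tir arguments are bound tuple vars.
    mod = relax.transform.RemovePurityChecking()(Before)
    mod = relax.transform.CallTIRRewrite()(mod)
    # The rewritten module allocates output tensors and calls each
    # PrimFunc directly, inserting TupleGetItem only where the bound
    # tuple cannot be unwrapped to a tuple literal.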