From 16891200bfeaeab3fa6e11770d3eb7bc07b20e25 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 6 Dec 2023 17:47:19 +0000 Subject: [PATCH 1/3] [Unity][Transform] Implement UpdateParamStructInfo Provide a convenience method to update parameter struct info, propagating any changes to internal bindings and return type. --- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 27 ++++- .../transform/update_param_struct_info.cc | 111 ++++++++++++++++++ ...test_transform_update_param_struct_info.py | 71 +++++++++++ 4 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 src/relax/transform/update_param_struct_info.cc create mode 100644 tests/python/relax/test_transform_update_param_struct_info.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 19316c76b83d..3a0460f99f4c 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -67,6 +67,7 @@ StaticPlanBlockMemory, ToMixedPrecision, ToNonDataflow, + UpdateParamStructInfo, UpdateVDevice, VMBuiltinLower, VMShapeLower, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 9589f661d79e..e4ba4b7d22f3 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -24,7 +24,7 @@ import numpy as np # type: ignore import tvm.ir -from tvm.relax import Expr, Var +from tvm.relax import Expr, Var, StructInfo from tvm.relax.dpl import DFPattern from tvm.runtime import NDArray, Object from tvm.tir import IndexMap, PrimFunc @@ -1224,6 +1224,31 @@ def SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) -> tvm.i return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore +def UpdateParamStructInfo(sinfo_func: Callable[[Var], Optional[StructInfo]]): + """Update struct info of parameters + + Update struct info of parameters. 
Internal bindings and function + return type will be updated using relax's struct inference rules. + Errors resulting from struct inference will be propagated to the + user. + + Parameters + ---------- + sinfo_func: Callable[[Var], Optional[StructInfo]] + + A function that is called once for each function parameter, + and returns the updated struct info to be used for it. If the + function returns `None`, the parameter is not modified. + + Returns + ------- + ret : tvm.transform.Pass + The corresponding pass. + + """ + return _ffi_api.UpdateParamStructInfo(sinfo_func) # type: ignore + + def CombineParallelMatmul(check=None): """Combine multiple matmul operators sharing the same LHS matrix into one, followed by slicing. When all matmul branches in a tree have the same set of fused ops, diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc new file mode 100644 index 000000000000..ef27f23ca84c --- /dev/null +++ b/src/relax/transform/update_param_struct_info.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/transform/update_param_struct_info.cc
+ * \brief Mutate IRModule to accept new parameters
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <algorithm>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+class ParamStructInfoMutator : public ExprMutator {
+ public:
+  ParamStructInfoMutator(TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func)
+      : sinfo_func_(sinfo_func) {}
+
+  using ExprMutator::VisitExpr_;
+  using ExprMutator::VisitVarDef_;
+
+  Var VisitVarDef_(const VarNode* op) override {
+    Var var = ExprMutator::VisitVarDef_(op);
+    if (!inside_expr_) {
+      if (auto new_sinfo = sinfo_func_(var)) {
+        return Var(var->vid, new_sinfo);
+      }
+    }
+    return var;
+  }
+
+  Expr VisitExpr_(const SeqExprNode* op) override {
+    bool cache = inside_expr_;
+    inside_expr_ = true;
+    auto ret = ExprMutator::VisitExpr_(op);
+    inside_expr_ = cache;
+    return ret;
+  }
+
+  TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func_;
+  bool inside_expr_{false};
+};
+}  // namespace
+
+namespace transform {
+Pass UpdateParamStructInfo(TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func) {
+  auto pass_func = [=](IRModule mod, PassContext pc) {
+    ParamStructInfoMutator mutator(sinfo_func);
+
+    std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> to_remove;
+    std::unordered_map<GlobalVar, Function, ObjectPtrHash, ObjectPtrEqual> to_add;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto updated = Downcast<Function>(mutator(func.value()));
+        if (!updated.same_as(base_func)) {
+          GlobalVar new_gvar(gvar->name_hint);
+          UpdateStructInfo(new_gvar, GetStructInfo(updated));
+          to_add.insert({new_gvar, updated});
+          to_remove.insert(gvar);
+        }
+      }
+    }
+
+    if (to_remove.size() || to_add.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+
+      for (const auto& gvar : to_remove) {
+        write_ptr->Remove(gvar);
+      }
+      for (const auto& [gvar, func] : to_add) {
+        write_ptr->Add(gvar, func);
+      }
+    }
+
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo").set_body_typed(UpdateParamStructInfo); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_update_param_struct_info.py b/tests/python/relax/test_transform_update_param_struct_info.py new file mode 100644 index 000000000000..6680580fe063 --- /dev/null +++ b/tests/python/relax/test_transform_update_param_struct_info.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import inspect +from typing import Optional + +import pytest + +import tvm.testing +from tvm import relax +from tvm.script import ir as I, relax as R + + +class Base: + def test_compare(self): + transform = relax.transform.UpdateParamStructInfo(self.update_sinfo) + + if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception): + with pytest.raises(self.Expected): + transform(self.Before) + else: + after = transform(self.Before) + tvm.ir.assert_structural_equal(self.Expected, after) + + def update_sinfo(self, var: relax.Var) -> Optional[relax.StructInfo]: + """The struct info update function provided to the transform""" + raise NotImplementedError("Should be implemented in derived class") + + +class TestSimple(Base): + def update_sinfo(self, var: relax.Var) -> Optional[relax.StructInfo]: + if var.name_hint == "weight": + return relax.TensorStructInfo([64, 16], "float32") + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16], "float32"), + weight: R.Tensor([32, 16], "float32"), + ) -> R.Tensor([32], "float32"): + out: R.Tensor([32], "float32") = R.matmul(weight, x) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16], "float32"), + weight: R.Tensor([64, 16], "float32"), + ) -> R.Tensor([64], "float32"): + out: R.Tensor([64], "float32") = R.matmul(weight, x) + return out + + +if __name__ == "__main__": + tvm.testing.main() From 27d8f132a1419e9087a5759d019b6cdaba2e571e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Jan 2024 18:28:05 +0000 Subject: [PATCH 2/3] lint fix --- src/relax/transform/update_param_struct_info.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index ef27f23ca84c..6711f951404b 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -39,7 +39,7 @@ namespace relax { namespace { 
 class ParamStructInfoMutator : public ExprMutator {
  public:
-  ParamStructInfoMutator(TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func)
+  explicit ParamStructInfoMutator(TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func)
       : sinfo_func_(sinfo_func) {}
 
   using ExprMutator::VisitExpr_;

From 11463edc90076256fb035ebaded9f4a04707ed8e Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 3 Jan 2024 15:10:10 +0000
Subject: [PATCH 3/3] Update implementation to update params in relax::Function
 mutator

---
 .../transform/update_param_struct_info.cc     | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc
index 6711f951404b..327185fd0bc3 100644
--- a/src/relax/transform/update_param_struct_info.cc
+++ b/src/relax/transform/update_param_struct_info.cc
@@ -45,26 +45,26 @@ class ParamStructInfoMutator : public ExprMutator {
   using ExprMutator::VisitExpr_;
   using ExprMutator::VisitVarDef_;
 
-  Var VisitVarDef_(const VarNode* op) override {
-    Var var = ExprMutator::VisitVarDef_(op);
-    if (!inside_expr_) {
-      if (auto new_sinfo = sinfo_func_(var)) {
-        return Var(var->vid, new_sinfo);
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto func = GetRef<Function>(op);
+
+    auto params = op->params.Map([this](Var param) {
+      if (auto new_sinfo = sinfo_func_(param)) {
+        auto new_param = WithStructInfo(param, new_sinfo.value());
+        var_remap_[param->vid] = new_param;
+        return new_param;
+      } else {
+        return param;
       }
-    }
-    return var;
-  }
+    });
 
-  Expr VisitExpr_(const SeqExprNode* op) override {
-    bool cache = inside_expr_;
-    inside_expr_ = true;
-    auto ret = ExprMutator::VisitExpr_(op);
-    inside_expr_ = cache;
-    return ret;
+    if (!params.same_as(func->params)) {
+      func.CopyOnWrite()->params = params;
+    }
+    return ExprMutator::VisitExpr_(func.get());
   }
 
   TypedPackedFunc<Optional<StructInfo>(Var)> sinfo_func_;
-  bool inside_expr_{false};
 };
 }  // namespace