From 517a93d858b8b3eadf130dbd7284438fc10ceeb5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 19 Jan 2024 16:05:46 +0000 Subject: [PATCH] [Unity] Check for transpose and dynamic shape in AdjustMatmulOrder When determining whether to evaluate matrix multiplications as `(A*B)*C` or as `A*(B*C)`, dynamic shapes may occur (e.g. a dynamic LoRA rank). This commit tests for these cases, and improves the arithmetic bounds used to prove which order of evaluation is preferred. As part of the implementation, this commit also adds a utility `CollectNonNegativeExpressions`, exposed to the python API as `relax.analysis.collect_non_negative_expresisons`. This utility collects expressions within a `StructInfo` which must be non-negative, based on the location where they appear. For example, the size of a tensor along each dimension must be non-negative. Unlike the existing `defineable_tir_vars_in_struct_info`, this will include the `N-2` expression in `R.Tensor([N-2])`. --- include/tvm/relax/analysis.h | 13 ++ python/tvm/relax/analysis/__init__.py | 1 + python/tvm/relax/analysis/analysis.py | 27 +++ src/relax/analysis/struct_info_analysis.cc | 45 +++++ src/relax/ir/expr_functor.cc | 11 ++ src/relax/transform/adjust_matmul_order.cc | 83 +++++++-- .../test_analysis_struct_info_analysis.py | 43 +++++ .../test_transform_adjust_matmul_order.py | 164 ++++++++++++++++++ 8 files changed, 368 insertions(+), 19 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 0c4373281323..7f1c3c06e5f5 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -304,6 +304,19 @@ TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); */ TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +/*! \brief Collect expressions whose usage requires them to be non-negative + * + * Any PrimExpr that is used as a tensor shape, or as an element in a + * ShapeExpr, may not be negative. This utility function can be used + * to generate assertions prior to calling a kernel, or to provide + * assumptions within a kernel that may be useful for simplification. + * + * \param sinfo The struct info to be analyzed + * + * \return A list of non-negative expressions. + */ +TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); + /*! * \brief Get the TIR variables that defined in the input function. * The returned list is deduplicated - each TIR variable will appear at most once. diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index 06b4f6432681..592e3bb5db51 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -21,6 +21,7 @@ all_global_vars, all_vars, bound_vars, + collect_non_negative_expressions, computable_at_compile_time, contains_impure_call, definable_tir_vars_in_struct_info, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 83286c09803a..e5ccb67d9da4 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -202,6 +202,33 @@ def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]: return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore +def collect_non_negative_expressions(sinfo: StructInfo) -> List[tir.PrimExpr]: + """Collect TIR expressions used in non-negative contexts + + Get TIR variables that are non-negative within the context where + the struct info is used. For example, any expression used as a + tensor shape. + + The returned list is deduplicated - each TIR expression will + appear at most once. The order of the list is in the order of + occurrence within the struct info. + + Parameters + ---------- + sinfo : StructInfo + The struct info object to be analyzed. + + Returns + ------- + ret : List[tir.Var] + + The list of TIR variables that can be defined from the StructInfo + + """ + + return _ffi_api.CollectNonNegativeExpressions(sinfo) # type: ignore + + def defined_symbolic_vars(func: Function) -> List[Var]: """Get the TIR variables that defined in the input function. The returned list is deduplicated - each TIR variable will appear at most once. diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index b1932f9b5d67..e58ad4c928aa 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -1219,6 +1219,51 @@ TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVars TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") .set_body_typed(DefinableTIRVarsInStructInfo); +class NonNegativeExpressionCollector : relax::StructInfoVisitor { + public: + static Array Collect(const StructInfo& sinfo) { + NonNegativeExpressionCollector visitor; + visitor(sinfo); + return visitor.expressions_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) override { + if (op->shape.defined()) { + VisitStructInfo(GetStructInfo(op->shape.value())); + } + } + + void VisitStructInfo_(const PrimStructInfoNode* op) override { + // Unlike the expressions in TensorStructInfo or ShapeStructInfo, + // PrimStructInfo may contain negative values. This override + // prevents calling VisitStructInfoExprField from the default + // StructInfoVisitor implementation. + } + + void VisitStructInfoExprField(const PrimExpr& size_expr) override { + if (auto size_int = size_expr.as(); size_int && size_int->value >= 0) { + // Avoid cluttering the result with non-negative integers + return; + } + + if (!dedup_lookup_.count(size_expr)) { + expressions_.push_back(size_expr); + dedup_lookup_.insert(size_expr); + } + } + + Array expressions_; + std::unordered_set dedup_lookup_; +}; + +Array CollectNonNegativeExpressions(const StructInfo& sinfo) { + return NonNegativeExpressionCollector::Collect(sinfo); +} + +TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") + .set_body_typed(CollectNonNegativeExpressions); + class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index e01b710df133..dbfaf60fecfc 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -779,7 +779,18 @@ Var ExprMutator::VisitVarDef(const Var& var) { Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { ICHECK(expr->IsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; + + PrimExpr constraint = Bool(true); + if (params.defined()) { + auto non_negative_expressions = + CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo))); + for (const auto& expr : non_negative_expressions) { + constraint = constraint && (expr >= 0); + } + } + builder_->BeginScope(params); + With context(builder_->GetAnalyzer(), constraint); Expr ret = this->VisitExpr(expr); builder_->EndScope(); return ret; diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 399860987c01..10b026785171 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -33,6 +33,7 @@ #include #include "../op/tensor/linear_algebra.h" +#include "../op/tensor/manipulate.h" namespace tvm { namespace relax { @@ -60,11 +61,34 @@ std::tuple)>> CreateP DFPattern pat_c = WildcardPattern(); auto pat_matmul = IsOp("relax.matmul"); + auto pat_permute_dims = IsOp("relax.permute_dims"); auto pat_matmul_on_lhs = pat_matmul(pat_matmul(pat_a, pat_b), pat_c); auto pat_matmul_on_rhs = pat_matmul(pat_a, pat_matmul(pat_b, pat_c)); - auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs; + auto pat_permuted_matmul_on_lhs = pat_matmul(pat_permute_dims(pat_matmul(pat_b, pat_a)), pat_c); + auto pat_permuted_matmul_on_rhs = pat_matmul(pat_a, pat_permute_dims(pat_matmul(pat_c, pat_b))); + + auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs | + pat_permuted_matmul_on_rhs; + + PrimExpr symbolic_var_constraints = Bool(true); + if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) { + Map name_lookup; + for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { + name_lookup.Set(tir_var->name_hint, tir_var); + symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); + } + + for (const auto& [key, obj_bound] : upper_bounds.value()) { + auto tir_var_name = Downcast(key); + if (auto opt_var = name_lookup.Get(tir_var_name)) { + auto var = opt_var.value(); + auto expr_bound = Downcast(obj_bound); + symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound); + } + } + } auto rewriter = [=](Expr expr, Map matches) -> Expr { auto expr_a = matches[pat_a]; @@ -78,23 +102,6 @@ std::tuple)>> CreateP return expr; } - // If two of the three are compile-time, group those two values - // together, to allow them to be lifted out and pre-computed. - if (is_compile_time(expr_a) && is_compile_time(expr_b)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); - } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); - } - - // Otherwise, select the order that reduces the total number of - // operations required, assuming a naive matmul. - - // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] - // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) - // - // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` - // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` - auto get_shape = [](Expr expr) -> Optional> { auto sinfo = expr->struct_info_.as(); if (sinfo) { @@ -115,6 +122,39 @@ std::tuple)>> CreateP auto shape_b = opt_shape_b.value(); auto shape_c = opt_shape_c.value(); + if (matches.count(pat_permuted_matmul_on_lhs)) { + expr_a = permute_dims(expr_a, NullOpt); + expr_b = permute_dims(expr_b, NullOpt); + CHECK_EQ(shape_a.size(), 2); + CHECK_EQ(shape_b.size(), 2); + shape_a = {shape_a[1], shape_a[0]}; + shape_b = {shape_b[1], shape_b[0]}; + } else if (matches.count(pat_permuted_matmul_on_rhs)) { + expr_b = permute_dims(expr_b, NullOpt); + expr_c = permute_dims(expr_c, NullOpt); + CHECK_EQ(shape_b.size(), 2); + CHECK_EQ(shape_c.size(), 2); + shape_b = {shape_b[1], shape_b[0]}; + shape_c = {shape_c[1], shape_c[0]}; + } + + // If two of the three are compile-time, group those two values + // together, to allow them to be lifted out and pre-computed. + if (is_compile_time(expr_a) && is_compile_time(expr_b)) { + return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { + return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + } + + // Otherwise, select the order that reduces the total number of + // operations required, assuming a naive matmul. + + // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] + // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) + // + // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` + // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` + if (shape_a.size() == 1) { shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; } @@ -142,8 +182,13 @@ std::tuple)>> CreateP auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B; arith::Analyzer analyzer; + analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( + analyzer.rewrite_simplify.GetEnabledExtensions() | + arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); + With func_attr_constraint(&analyzer, symbolic_var_constraints); With analyzer_constraint( - &analyzer, size_N >= 0 && size_R >= 0 && size_M >= 0 && size_B >= 0); + &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); + if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) { return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); } else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) { diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index b28df7b22441..83b1ddd4fc9e 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -24,6 +24,7 @@ from tvm import TVMError from tvm import relax as rx from tvm import tir, ir +from tvm.script import relax as R def test_get_static_type_basic(): @@ -718,5 +719,47 @@ def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): assert free_vars == set() +def test_collect_nonnegative_expressions(): + @R.function + def func( + A: R.Tensor([1024, "M", "N-2"]), + B: R.Tensor([128, "N", "M+2"]), + C: R.Shape(["M", "N"]), + D: R.Prim(value="N"), + ): + return R.tuple() + + M, N = list(func.params[2].struct_info.values) + + # Expressions are de-duplicated, in order of their first appearance + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.struct_info), + [M, N - 2, N, M + 2], + ) + + # Tensor shapes can imply that their shapes are non-negative + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[0].struct_info), + [M, N - 2], + ) + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[1].struct_info), + [N, M + 2], + ) + + # ShapeExpr values can imply that their contents are non-negative + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[2].struct_info), + [M, N], + ) + + # PrimValue instances may contain negative values, and do not + # imply that their contents are non-negative. + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[3].struct_info), + [], + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index 8b5a26682a08..5112bf53844b 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -347,5 +347,169 @@ def main( Expected = Before +class TestRHSPermuteDims(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHS`, but the weights on the RHS are transposed. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, 2]), + B: R.Tensor([2, 16]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 16]) = R.matmul(A, B) + matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, 2]), + B: R.Tensor([2, 16]), + ) -> R.Tensor([32]): + B_transpose = R.permute_dims(B) + x: R.Tensor([2]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsDynamic(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but the weights on the RHS have a + dynamic shape. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 16]) = R.matmul(A, B) + matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor([32]): + lora_r = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsWithDynamicBatch(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but both the weights on the RHS and the + activations on the LHS have a dynamic dimension. + + Unlike most of the tests for this transform, the + `tir_vars_upper_bound` attribute is required. In order to make a + change, `AdjustMatmulOrder` must first prove that the modified + execution order reduces the number of computations. + + ops_left_to_right = (batch_size + lora_r)*4096*4096 + ops_right_to_left = (4096 + 4096)*batch_size*lora_r + + Without an upper bound on `lora_r`, we cannot prove which of these + is the preferred execution order. With the upper bound, TVM can + determine the preferred order using the following arithmethic + reasoning. + + (batch_size + lora_r)*4096*4096 < (4096 + 4096)*batch_size*lora_r + (batch_size + lora_r)*2048 < batch_size*lora_r + 1/batch_size + 1/lora_r < 1/2048 + + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 4096]), + A: R.Tensor([4096, "lora_r"]), + B: R.Tensor(["lora_r", 4096]), + ) -> R.Tensor(["batch_size", 4096]): + R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + batch_size = T.int64() + linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B) + matmul_weight: R.Tensor([4096, 4096]) = R.permute_dims(linear_weight) + out: R.Tensor([batch_size, 4096]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 4096]), + A: R.Tensor([4096, "lora_r"]), + B: R.Tensor(["lora_r", 4096]), + ) -> R.Tensor(["batch_size", 4096]): + R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + lora_r = T.int64() + batch_size = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([batch_size, lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([batch_size, 4096]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsDynamicWithSquareMatrix(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but the weights on the RHS have a + dynamic shape. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([32]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 32]) = R.matmul(A, B) + matmul_weight: R.Tensor([32, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([32]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor([32]): + lora_r = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + if __name__ == "__main__": tvm.testing.main()