diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 0c4373281323..7f1c3c06e5f5 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -304,6 +304,19 @@ TVM_DLL Array TIRVarsInStructInfo(const StructInfo& sinfo); */ TVM_DLL Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +/*! \brief Collect expressions whose usage requires them to be non-negative + * + * Any PrimExpr that is used as a tensor shape, or as an element in a + * ShapeExpr, may not be negative. This utility function can be used + * to generate assertions prior to calling a kernel, or to provide + * assumptions within a kernel that may be useful for simplification. + * + * \param sinfo The struct info to be analyzed + * + * \return A list of non-negative expressions. + */ +TVM_DLL Array CollectNonNegativeExpressions(const StructInfo& sinfo); + /*! * \brief Get the TIR variables that defined in the input function. * The returned list is deduplicated - each TIR variable will appear at most once. 
diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index 06b4f6432681..592e3bb5db51 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -21,6 +21,7 @@ all_global_vars, all_vars, bound_vars, + collect_non_negative_expressions, computable_at_compile_time, contains_impure_call, definable_tir_vars_in_struct_info, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 83286c09803a..e5ccb67d9da4 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -202,6 +202,33 @@ def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]: return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore +def collect_non_negative_expressions(sinfo: StructInfo) -> List[tir.PrimExpr]: + """Collect TIR expressions used in non-negative contexts + + Get TIR expressions that are non-negative within the context where + the struct info is used. For example, any expression used as a + tensor shape. + + The returned list is deduplicated - each TIR expression will + appear at most once. The order of the list is in the order of + occurrence within the struct info. + + Parameters + ---------- + sinfo : StructInfo + The struct info object to be analyzed. + + Returns + ------- + ret : List[tir.PrimExpr] + + The list of TIR expressions required to be non-negative by the StructInfo + + """ + + return _ffi_api.CollectNonNegativeExpressions(sinfo) # type: ignore + + def defined_symbolic_vars(func: Function) -> List[Var]: """Get the TIR variables that defined in the input function. The returned list is deduplicated - each TIR variable will appear at most once.
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index b1932f9b5d67..e58ad4c928aa 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -1219,6 +1219,51 @@ TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVars TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") .set_body_typed(DefinableTIRVarsInStructInfo); +class NonNegativeExpressionCollector : relax::StructInfoVisitor { + public: + static Array Collect(const StructInfo& sinfo) { + NonNegativeExpressionCollector visitor; + visitor(sinfo); + return visitor.expressions_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) override { + if (op->shape.defined()) { + VisitStructInfo(GetStructInfo(op->shape.value())); + } + } + + void VisitStructInfo_(const PrimStructInfoNode* op) override { + // Unlike the expressions in TensorStructInfo or ShapeStructInfo, + // PrimStructInfo may contain negative values. This override + // prevents calling VisitStructInfoExprField from the default + // StructInfoVisitor implementation. 
+ } + + void VisitStructInfoExprField(const PrimExpr& size_expr) override { + if (auto size_int = size_expr.as(); size_int && size_int->value >= 0) { + // Avoid cluttering the result with non-negative integers + return; + } + + if (!dedup_lookup_.count(size_expr)) { + expressions_.push_back(size_expr); + dedup_lookup_.insert(size_expr); + } + } + + Array expressions_; + std::unordered_set dedup_lookup_; +}; + +Array CollectNonNegativeExpressions(const StructInfo& sinfo) { + return NonNegativeExpressionCollector::Collect(sinfo); +} + +TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") + .set_body_typed(CollectNonNegativeExpressions); + class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, public tir::ExprVisitor { diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index e01b710df133..dbfaf60fecfc 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -779,7 +779,18 @@ Var ExprMutator::VisitVarDef(const Var& var) { Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { ICHECK(expr->IsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; + + PrimExpr constraint = Bool(true); + if (params.defined()) { + auto non_negative_expressions = + CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo))); + for (const auto& expr : non_negative_expressions) { + constraint = constraint && (expr >= 0); + } + } + builder_->BeginScope(params); + With context(builder_->GetAnalyzer(), constraint); Expr ret = this->VisitExpr(expr); builder_->EndScope(); return ret; diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 399860987c01..10b026785171 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -33,6 +33,7 @@ #include #include "../op/tensor/linear_algebra.h" +#include "../op/tensor/manipulate.h" namespace tvm { namespace 
relax { @@ -60,11 +61,34 @@ std::tuple)>> CreateP DFPattern pat_c = WildcardPattern(); auto pat_matmul = IsOp("relax.matmul"); + auto pat_permute_dims = IsOp("relax.permute_dims"); auto pat_matmul_on_lhs = pat_matmul(pat_matmul(pat_a, pat_b), pat_c); auto pat_matmul_on_rhs = pat_matmul(pat_a, pat_matmul(pat_b, pat_c)); - auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs; + auto pat_permuted_matmul_on_lhs = pat_matmul(pat_permute_dims(pat_matmul(pat_b, pat_a)), pat_c); + auto pat_permuted_matmul_on_rhs = pat_matmul(pat_a, pat_permute_dims(pat_matmul(pat_c, pat_b))); + + auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs | + pat_permuted_matmul_on_rhs; + + PrimExpr symbolic_var_constraints = Bool(true); + if (auto upper_bounds = func->GetAttr>("tir_var_upper_bound")) { + Map name_lookup; + for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { + name_lookup.Set(tir_var->name_hint, tir_var); + symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); + } + + for (const auto& [key, obj_bound] : upper_bounds.value()) { + auto tir_var_name = Downcast(key); + if (auto opt_var = name_lookup.Get(tir_var_name)) { + auto var = opt_var.value(); + auto expr_bound = Downcast(obj_bound); + symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound); + } + } + } auto rewriter = [=](Expr expr, Map matches) -> Expr { auto expr_a = matches[pat_a]; @@ -78,23 +102,6 @@ std::tuple)>> CreateP return expr; } - // If two of the three are compile-time, group those two values - // together, to allow them to be lifted out and pre-computed. 
- if (is_compile_time(expr_a) && is_compile_time(expr_b)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); - } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); - } - - // Otherwise, select the order that reduces the total number of - // operations required, assuming a naive matmul. - - // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] - // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) - // - // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` - // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` - auto get_shape = [](Expr expr) -> Optional> { auto sinfo = expr->struct_info_.as(); if (sinfo) { @@ -115,6 +122,39 @@ std::tuple)>> CreateP auto shape_b = opt_shape_b.value(); auto shape_c = opt_shape_c.value(); + if (matches.count(pat_permuted_matmul_on_lhs)) { + expr_a = permute_dims(expr_a, NullOpt); + expr_b = permute_dims(expr_b, NullOpt); + CHECK_EQ(shape_a.size(), 2); + CHECK_EQ(shape_b.size(), 2); + shape_a = {shape_a[1], shape_a[0]}; + shape_b = {shape_b[1], shape_b[0]}; + } else if (matches.count(pat_permuted_matmul_on_rhs)) { + expr_b = permute_dims(expr_b, NullOpt); + expr_c = permute_dims(expr_c, NullOpt); + CHECK_EQ(shape_b.size(), 2); + CHECK_EQ(shape_c.size(), 2); + shape_b = {shape_b[1], shape_b[0]}; + shape_c = {shape_c[1], shape_c[0]}; + } + + // If two of the three are compile-time, group those two values + // together, to allow them to be lifted out and pre-computed. + if (is_compile_time(expr_a) && is_compile_time(expr_b)) { + return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { + return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + } + + // Otherwise, select the order that reduces the total number of + // operations required, assuming a naive matmul. 
+ + // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] + // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) + // + // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` + // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` + if (shape_a.size() == 1) { shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; } @@ -142,8 +182,13 @@ std::tuple)>> CreateP auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B; arith::Analyzer analyzer; + analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( + analyzer.rewrite_simplify.GetEnabledExtensions() | + arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); + With func_attr_constraint(&analyzer, symbolic_var_constraints); With analyzer_constraint( - &analyzer, size_N >= 0 && size_R >= 0 && size_M >= 0 && size_B >= 0); + &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); + if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) { return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); } else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) { diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index b28df7b22441..83b1ddd4fc9e 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -24,6 +24,7 @@ from tvm import TVMError from tvm import relax as rx from tvm import tir, ir +from tvm.script import relax as R def test_get_static_type_basic(): @@ -718,5 +719,47 @@ def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): assert free_vars == set() +def test_collect_nonnegative_expressions(): + @R.function + def func( + A: R.Tensor([1024, "M", "N-2"]), + B: R.Tensor([128, "N", "M+2"]), + C: R.Shape(["M", "N"]), + D: R.Prim(value="N"), + ): + return R.tuple() + + M, N = list(func.params[2].struct_info.values) + + # Expressions are de-duplicated, in order of their first appearance + 
tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.struct_info), + [M, N - 2, N, M + 2], + ) + + # Tensor shapes can imply that their shapes are non-negative + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[0].struct_info), + [M, N - 2], + ) + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[1].struct_info), + [N, M + 2], + ) + + # ShapeExpr values can imply that their contents are non-negative + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[2].struct_info), + [M, N], + ) + + # PrimValue instances may contain negative values, and do not + # imply that their contents are non-negative. + tvm.ir.assert_structural_equal( + rx.analysis.collect_non_negative_expressions(func.params[3].struct_info), + [], + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index 8b5a26682a08..5112bf53844b 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -347,5 +347,169 @@ def main( Expected = Before +class TestRHSPermuteDims(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHS`, but the weights on the RHS are transposed. 
+ """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, 2]), + B: R.Tensor([2, 16]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 16]) = R.matmul(A, B) + matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, 2]), + B: R.Tensor([2, 16]), + ) -> R.Tensor([32]): + B_transpose = R.permute_dims(B) + x: R.Tensor([2]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsDynamic(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but the weights on the RHS have a + dynamic shape. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 16]) = R.matmul(A, B) + matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor([32]): + lora_r = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsWithDynamicBatch(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but both the weights on the RHS and the + activations on the LHS have a dynamic dimension. + + Unlike most of the tests for this transform, the + `tir_var_upper_bound` attribute is required.
In order to make a + change, `AdjustMatmulOrder` must first prove that the modified + execution order reduces the number of computations. + + ops_left_to_right = (4096 + 4096)*batch_size*lora_r + ops_right_to_left = (batch_size + lora_r)*4096*4096 + + Without an upper bound on `lora_r`, we cannot prove which of these + is the preferred execution order. With the upper bound, TVM can + determine the preferred order using the following arithmetic + reasoning. + + (4096 + 4096)*batch_size*lora_r < (batch_size + lora_r)*4096*4096 + batch_size*lora_r < (batch_size + lora_r)*2048 + 1/2048 < 1/batch_size + 1/lora_r + + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 4096]), + A: R.Tensor([4096, "lora_r"]), + B: R.Tensor(["lora_r", 4096]), + ) -> R.Tensor(["batch_size", 4096]): + R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + batch_size = T.int64() + linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B) + matmul_weight: R.Tensor([4096, 4096]) = R.permute_dims(linear_weight) + out: R.Tensor([batch_size, 4096]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 4096]), + A: R.Tensor([4096, "lora_r"]), + B: R.Tensor(["lora_r", 4096]), + ) -> R.Tensor(["batch_size", 4096]): + R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}}) + lora_r = T.int64() + batch_size = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([batch_size, lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([batch_size, 4096]) = R.matmul(x, A_transpose) + return x + + +class TestRHSPermuteDimsDynamicWithSquareMatrix(Base): + """Prefer (x*A)*B instead of x*(A*B) + + Like `TestRHSPermuteDims`, but the weights on the RHS have a + dynamic shape.
+ """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([32]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor([32]): + linear_weight: R.Tensor([32, 32]) = R.matmul(A, B) + matmul_weight: R.Tensor([32, 32]) = R.permute_dims(linear_weight) + out: R.Tensor([32]) = R.matmul(x, matmul_weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([32]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor([32]): + lora_r = T.int64() + B_transpose = R.permute_dims(B) + x: R.Tensor([lora_r]) = R.matmul(x, B_transpose) + A_transpose = R.permute_dims(A) + x: R.Tensor([32]) = R.matmul(x, A_transpose) + return x + + if __name__ == "__main__": tvm.testing.main()