Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
*/
TVM_DLL Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);

/*! \brief Collect expressions whose usage requires them to be non-negative
 *
 * Any PrimExpr that is used as a tensor shape, or as an element in a
 * ShapeExpr, may not be negative.  This utility function can be used
 * to generate assertions prior to calling a kernel, or to provide
 * assumptions within a kernel that may be useful for simplification.
 *
 * Expressions held by a PrimStructInfo are not collected, because a
 * PrimValue may legitimately be negative.  Integer constants that are
 * already non-negative are omitted from the result.  The result is
 * de-duplicated by structural equality, and is ordered by first
 * occurrence within the struct info.
 *
 * \param sinfo The struct info to be analyzed
 *
 * \return A list of expressions that are required to be non-negative.
 */
TVM_DLL Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo);

/*!
* \brief Get the TIR variables that defined in the input function.
* The returned list is deduplicated - each TIR variable will appear at most once.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
all_global_vars,
all_vars,
bound_vars,
collect_non_negative_expressions,
computable_at_compile_time,
contains_impure_call,
definable_tir_vars_in_struct_info,
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,33 @@ def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore


def collect_non_negative_expressions(sinfo: StructInfo) -> List[tir.PrimExpr]:
    """Collect TIR expressions used in non-negative contexts

    Get the TIR expressions that are required to be non-negative
    within the context where the struct info is used.  For example,
    any expression used as a tensor shape.

    The returned list is deduplicated - each TIR expression will
    appear at most once.  The order of the list is in the order of
    occurrence within the struct info.

    Parameters
    ----------
    sinfo : StructInfo
        The struct info object to be analyzed.

    Returns
    -------
    ret : List[tir.PrimExpr]
        The list of TIR expressions that the StructInfo requires to
        be non-negative.
    """
    return _ffi_api.CollectNonNegativeExpressions(sinfo)  # type: ignore


def defined_symbolic_vars(func: Function) -> List[Var]:
"""Get the TIR variables that defined in the input function.
The returned list is deduplicated - each TIR variable will appear at most once.
Expand Down
45 changes: 45 additions & 0 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,51 @@ TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVars
// Register DefinableTIRVarsInStructInfo as a global function under the
// name "relax.analysis.DefinableTIRVarsInStructInfo".
TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo")
    .set_body_typed(DefinableTIRVarsInStructInfo);

/*!
 * \brief Visitor that gathers the expressions of a StructInfo that
 * appear where a negative value is not permitted.
 *
 * Expressions are recorded once each (compared structurally), in the
 * order in which they are first visited.
 */
class NonNegativeExpressionCollector : relax::StructInfoVisitor {
 public:
  /*!
   * \brief Run the collector over a struct info.
   *
   * \param sinfo The struct info to walk.
   *
   * \return The collected expressions, in order of first occurrence.
   */
  static Array<PrimExpr> Collect(const StructInfo& sinfo) {
    NonNegativeExpressionCollector collector;
    collector(sinfo);
    return collector.expressions_;
  }

 private:
  void VisitStructInfo_(const TensorStructInfoNode* op) override {
    // A tensor with no shape expression contributes nothing.
    if (!op->shape.defined()) {
      return;
    }
    // Walk the struct info of the shape expression, rather than the
    // shape expression itself.
    VisitStructInfo(GetStructInfo(op->shape.value()));
  }

  void VisitStructInfo_(const PrimStructInfoNode* op) override {
    // Intentionally empty.  A PrimStructInfo, unlike a
    // TensorStructInfo or ShapeStructInfo, is allowed to hold a
    // negative value.  Overriding the visit here keeps the default
    // StructInfoVisitor implementation from forwarding its value to
    // VisitStructInfoExprField.
  }

  void VisitStructInfoExprField(const PrimExpr& size_expr) override {
    // Non-negative integer literals carry no information, so they
    // are left out of the result.
    const auto* as_int = size_expr.as<IntImmNode>();
    const bool is_non_negative_constant = as_int && as_int->value >= 0;
    if (is_non_negative_constant) {
      return;
    }

    // `insert` reports whether the expression was newly added, which
    // is exactly when it should be appended to the ordered result.
    const bool first_occurrence = dedup_lookup_.insert(size_expr).second;
    if (first_occurrence) {
      expressions_.push_back(size_expr);
    }
  }

  /*! \brief Collected expressions, in order of first occurrence. */
  Array<PrimExpr> expressions_;
  /*! \brief Structural-equality set used to skip repeated expressions. */
  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> dedup_lookup_;
};

/*!
 * \brief Collect the expressions in `sinfo` that are used where a
 * non-negative value is required.
 *
 * Thin wrapper that delegates to NonNegativeExpressionCollector::Collect.
 *
 * \param sinfo The struct info to be analyzed
 *
 * \return The collected expressions.
 */
Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo) {
  return NonNegativeExpressionCollector::Collect(sinfo);
}

// Register the function under the global name
// "relax.analysis.CollectNonNegativeExpressions".
TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions")
    .set_body_typed(CollectNonNegativeExpressions);

class SymbolicVarCollector : public relax::ExprVisitor,
public relax::StructInfoVisitor,
public tir::ExprVisitor {
Expand Down
11 changes: 11 additions & 0 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,18 @@ Var ExprMutator::VisitVarDef(const Var& var) {
Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params) {
ICHECK(expr->IsInstance<SeqExprNode>())
<< "Normal form requires all new scope is stored as SeqExpr";

PrimExpr constraint = Bool(true);
if (params.defined()) {
auto non_negative_expressions =
CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo)));
for (const auto& expr : non_negative_expressions) {
constraint = constraint && (expr >= 0);
}
}

builder_->BeginScope(params);
With<arith::ConstraintContext> context(builder_->GetAnalyzer(), constraint);
Expr ret = this->VisitExpr(expr);
builder_->EndScope();
return ret;
Expand Down
83 changes: 64 additions & 19 deletions src/relax/transform/adjust_matmul_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <vector>

#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -60,11 +61,34 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
DFPattern pat_c = WildcardPattern();

auto pat_matmul = IsOp("relax.matmul");
auto pat_permute_dims = IsOp("relax.permute_dims");

auto pat_matmul_on_lhs = pat_matmul(pat_matmul(pat_a, pat_b), pat_c);
auto pat_matmul_on_rhs = pat_matmul(pat_a, pat_matmul(pat_b, pat_c));

auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs;
auto pat_permuted_matmul_on_lhs = pat_matmul(pat_permute_dims(pat_matmul(pat_b, pat_a)), pat_c);
auto pat_permuted_matmul_on_rhs = pat_matmul(pat_a, pat_permute_dims(pat_matmul(pat_c, pat_b)));

auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs |
pat_permuted_matmul_on_rhs;

PrimExpr symbolic_var_constraints = Bool(true);
if (auto upper_bounds = func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")) {
Map<String, tir::Var> name_lookup;
for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
name_lookup.Set(tir_var->name_hint, tir_var);
symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
}

for (const auto& [key, obj_bound] : upper_bounds.value()) {
auto tir_var_name = Downcast<String>(key);
if (auto opt_var = name_lookup.Get(tir_var_name)) {
auto var = opt_var.value();
auto expr_bound = Downcast<PrimExpr>(obj_bound);
symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
}
}
}

auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
auto expr_a = matches[pat_a];
Expand All @@ -78,23 +102,6 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
return expr;
}

// If two of the three are compile-time, group those two values
// together, to allow them to be lifted out and pre-computed.
if (is_compile_time(expr_a) && is_compile_time(expr_b)) {
return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void());
} else if (is_compile_time(expr_b) && is_compile_time(expr_c)) {
return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void());
}

// Otherwise, select the order that reduces the total number of
// operations required, assuming a naive matmul.

// Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
// Matmul on RHS: [N,R] * ([R,M]*[M,batch])
//
// LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
// RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`

auto get_shape = [](Expr expr) -> Optional<Array<PrimExpr>> {
auto sinfo = expr->struct_info_.as<TensorStructInfoNode>();
if (sinfo) {
Expand All @@ -115,6 +122,39 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
auto shape_b = opt_shape_b.value();
auto shape_c = opt_shape_c.value();

if (matches.count(pat_permuted_matmul_on_lhs)) {
expr_a = permute_dims(expr_a, NullOpt);
expr_b = permute_dims(expr_b, NullOpt);
CHECK_EQ(shape_a.size(), 2);
CHECK_EQ(shape_b.size(), 2);
shape_a = {shape_a[1], shape_a[0]};
shape_b = {shape_b[1], shape_b[0]};
} else if (matches.count(pat_permuted_matmul_on_rhs)) {
expr_b = permute_dims(expr_b, NullOpt);
expr_c = permute_dims(expr_c, NullOpt);
CHECK_EQ(shape_b.size(), 2);
CHECK_EQ(shape_c.size(), 2);
shape_b = {shape_b[1], shape_b[0]};
shape_c = {shape_c[1], shape_c[0]};
}

// If two of the three are compile-time, group those two values
// together, to allow them to be lifted out and pre-computed.
if (is_compile_time(expr_a) && is_compile_time(expr_b)) {
return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void());
} else if (is_compile_time(expr_b) && is_compile_time(expr_c)) {
return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void());
}

// Otherwise, select the order that reduces the total number of
// operations required, assuming a naive matmul.

// Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
// Matmul on RHS: [N,R] * ([R,M]*[M,batch])
//
// LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
// RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`

if (shape_a.size() == 1) {
shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
}
Expand Down Expand Up @@ -142,8 +182,13 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;

arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
analyzer.rewrite_simplify.GetEnabledExtensions() |
arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum));
With<arith::ConstraintContext> func_attr_constraint(&analyzer, symbolic_var_constraints);
With<arith::ConstraintContext> analyzer_constraint(
&analyzer, size_N >= 0 && size_R >= 0 && size_M >= 0 && size_B >= 0);
&analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);

if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) {
return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void());
} else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) {
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import TVMError
from tvm import relax as rx
from tvm import tir, ir
from tvm.script import relax as R


def test_get_static_type_basic():
Expand Down Expand Up @@ -718,5 +719,47 @@ def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
assert free_vars == set()


def test_collect_nonnegative_expressions():
    """Check which expressions each kind of StructInfo reports as non-negative."""

    @R.function
    def func(
        A: R.Tensor([1024, "M", "N-2"]),
        B: R.Tensor([128, "N", "M+2"]),
        C: R.Shape(["M", "N"]),
        D: R.Prim(value="N"),
    ):
        return R.tuple()

    sinfo_a, sinfo_b, sinfo_c, sinfo_d = [param.struct_info for param in func.params]
    M, N = list(sinfo_c.values)

    # Each entry pairs a struct info with the expressions it is expected
    # to report.  The entries are checked in the order listed.
    cases = [
        # The whole function signature: results are de-duplicated and
        # listed in order of first appearance.
        (func.struct_info, [M, N - 2, N, M + 2]),
        # The dimensions of a tensor shape must be non-negative.
        (sinfo_a, [M, N - 2]),
        (sinfo_b, [N, M + 2]),
        # The elements of a ShapeExpr must be non-negative.
        (sinfo_c, [M, N]),
        # A PrimValue is allowed to be negative, so nothing is reported.
        (sinfo_d, []),
    ]

    for sinfo, expected in cases:
        tvm.ir.assert_structural_equal(
            rx.analysis.collect_non_negative_expressions(sinfo),
            expected,
        )


# When executed as a script, hand control to tvm.testing.main()
# (presumably runs the tests defined in this file -- confirm against tvm.testing).
if __name__ == "__main__":
    tvm.testing.main()
Loading