Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/tir/analysis/estimate_flops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "tvm/arith/analyzer.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -84,6 +86,8 @@ struct TResult {

class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
private StmtFunctor<TResult(const Stmt& n)> {
arith::Analyzer ana;

public:
TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }
Expand Down Expand Up @@ -112,6 +116,15 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
TResult VisitExpr_(const GTNode* op) override { return TResult(); }
TResult VisitExpr_(const GENode* op) override { return TResult(); }

int64_t GetLoopExtent(const ForNode* node, const arith::Analyzer& ana) {
int64_t bound = ana.const_int_bound(node->extent)->max_value;
if (bound == arith::ConstIntBound::kPosInf) {
return 1; // Analyzer could not determine a valid bound, use 1 instead.
} else {
return bound;
}
}

TResult VisitExpr_(const NotNode* op) override { return VisitExpr(op->a); }
TResult VisitExpr_(const AndNode* op) final {
TResult result = VisitExpr(op->a);
Expand All @@ -138,11 +151,10 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
return result;
}
TResult VisitStmt_(const ForNode* loop) override {
ana.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
const auto int_imm = GetLoopExtent(loop, ana);
TResult result = VisitStmt(loop->body);
const auto* int_imm = loop->extent.as<IntImmNode>();
ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
<< loop->extent->GetTypeKey();
result *= int_imm->value;
result *= int_imm;
return result;
}

Expand Down
28 changes: 22 additions & 6 deletions tests/python/unittest/test_tir_analysis_estimate_tir_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def flops_override(A: T.Buffer(16, "float32")):
A[0] = A[0] + 1


def test_estimate_flops_forloop_as_experssion():
def test_estimate_flops_forloop_as_expression():
flops = estimate_tir_flops(
IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)})
)
Expand All @@ -102,11 +102,6 @@ def test_estimate_flops_forloop_as_experssion():
assert flops == 32


def test_exception():
with pytest.raises(tvm.TVMError):
flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression}))


def test_estimate_flops_with_decl_buffer():
def make_func(use_decl_buffer):
buffer_func = T.decl_buffer if use_decl_buffer else T.Buffer
Expand All @@ -124,5 +119,26 @@ def func(A_data: T.handle("float32")):
assert flops_with_decl_buffer == flops_without_decl_buffer


@T.prim_func
def flops_with_nonint_extent(a: T.Buffer(16, "float32")):
for i in range(4 + 4):
a[i] = 2 * a[i]


def test_flops_with_nonint_extent():
assert estimate_tir_flops(IRModule({"main": flops_with_nonint_extent})) == 8


@T.prim_func
def flops_with_variable_extent(a: T.Buffer(16, "float32")):
for i in range(4 + 4):
for j in range(i + 8):
a[j] = 2 * a[i]


def test_flops_with_variable_extent():
assert estimate_tir_flops(IRModule({"main": flops_with_variable_extent})) == 120


if __name__ == "__main__":
tvm.testing.main()