Skip to content

Commit 9bcaeb0

Browse files
committed
[PASS] UnrollLoop
1 parent d89917b commit 9bcaeb0

8 files changed

Lines changed: 146 additions & 9 deletions

File tree

include/tvm/ir_pass.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
6868
* \param value_map The map of new values.
6969
* \return The converted form.
7070
*/
71-
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
71+
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
7272

7373
/*!
7474
* \brief inline all calls of f in stmt.
@@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt,
9797
Stmt StorageFlatten(Stmt stmt,
9898
Map<Tensor, Buffer> extern_buffer);
9999

100+
/*!
101+
* \brief unroll the constant loops
102+
* \param stmt The statement to be unrolled.
103+
* \param max_auto_step The maximum constant loop extent that is unrolled automatically.
104+
*/
105+
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
106+
100107
/*!
101108
* \brief Make a user-callable API LoweredFunc.
102109
*
@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
153160
*/
154161
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
155162

163+
156164
} // namespace ir
157165
} // namespace tvm
158166

include/tvm/runtime/packed_func.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
562562
CHECK_LT(i, num_args)
563563
<< "not enough argument passed, "
564564
<< num_args << " passed"
565-
<< "but request arg" << i;
565+
<< " but request arg[" << i << "].";
566566
return TVMArgValue(values[i], type_codes[i]);
567567
}
568568

src/api/api_pass.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <tvm/expr.h>
77
#include <tvm/ir.h>
88
#include <tvm/ir_pass.h>
9+
#include <tvm/ir_visitor.h>
910
#include <tvm/api_registry.h>
1011

1112
namespace tvm {
@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
2930
}
3031
});
3132

33+
TVM_REGISTER_API(_pass_PostOrderVisit)  // registers a packed-function wrapper around ir::PostOrderVisit
34+
.set_body([](TVMArgs args, TVMRetValue *ret) {  // ret is never written by this body
35+
PackedFunc f = args[1];  // args[1] is the callback; args[0] is what gets visited
36+
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {  // f is captured by value
37+
f(n);  // hand each visited node to the callback
38+
});
39+
});
40+
3241
// make from two arguments
3342
#define REGISTER_PASS1(PassName) \
3443
TVM_REGISTER_API(_pass_## PassName) \
@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
5261
REGISTER_PASS1(VerifySSA);
5362
REGISTER_PASS4(Inline);
5463
REGISTER_PASS2(StorageFlatten);
64+
REGISTER_PASS2(UnrollLoop);
5565
REGISTER_PASS2(StorageSync);
5666
REGISTER_PASS4(MakeAPI);
5767
REGISTER_PASS1(SplitHostDevice);

src/pass/inline.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,24 @@ class IRInline : public IRMutator {
2424
if (op->func == f_) {
2525
CHECK_EQ(op->value_index, 0);
2626
Expr expr = body_;
27-
CHECK_EQ(args_.size(), op->args.size())
28-
<< op->args.size() << " vs " << args_.size();
29-
for (size_t i = 0; i < args_.size(); ++i) {
30-
expr = Let::make(args_[i], op->args[i], expr);
27+
CHECK_EQ(args_.size(), op->args.size());
28+
29+
bool has_side_effect = false;
30+
for (size_t i = 0; i < op->args.size(); ++i) {
31+
if (HasSideEffect(op->args[i])) has_side_effect = true;
32+
}
33+
34+
if (has_side_effect) {
35+
for (size_t i = 0; i < args_.size(); ++i) {
36+
expr = Let::make(args_[i], op->args[i], expr);
37+
}
38+
} else {
39+
Map<Var, Expr> vmap;
40+
for (size_t i = 0; i < args_.size(); ++i) {
41+
vmap.Set(args_[i], op->args[i]);
42+
}
43+
expr = Substitute(
44+
Evaluate::make(expr), vmap).as<Evaluate>()->value;
3145
}
3246
return expr;
3347
} else {

src/pass/simple_passes.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
4747
std::unordered_map<const Variable*, Expr> smap;
4848
};
4949

50-
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
50+
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
5151
IRSubstitue m;
5252
for (auto kv : value_map) {
53-
m.smap[kv.first->var.get()] = kv.second;
53+
m.smap[kv.first.get()] = kv.second;
5454
}
5555
return m.Mutate(stmt);
5656
}

src/pass/unroll_loop.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* Loop unrolling pass.
4+
* \file unroll_loop.cc
5+
*/
6+
#include <tvm/ir.h>
7+
#include <tvm/ir_pass.h>
8+
#include <tvm/ir_mutator.h>
9+
#include <unordered_set>
10+
#include <unordered_map>
11+
#include <vector>
12+
#include "../schedule/compute_expr.h"
13+
14+
namespace tvm {
15+
namespace ir {
16+
17+
class LoopUnroller : public IRMutator {  // IR mutator that replaces eligible For loops with repeated copies of their body
18+
public:
19+
explicit LoopUnroller(int max_auto_step)  // max_auto_step: largest constant extent unrolled without an explicit request
20+
: max_auto_step_(max_auto_step) {
21+
}
22+
23+
Stmt Mutate_(const For* op, const Stmt& s) {  // returns the unrolled body, or the default-mutated loop if not unrolled
24+
Stmt stmt = s;
25+
// constant folding.
26+
Expr extent = ir::Simplify(op->extent);
27+
const IntImm* v1 = extent.as<IntImm>();  // non-null only when the simplified extent is a signed immediate
28+
const UIntImm* v2 = extent.as<UIntImm>();  // non-null only when the simplified extent is an unsigned immediate
29+
int value = -1;  // stays -1 when the extent is not an integer immediate
30+
if (v1 != nullptr) {
31+
value = static_cast<int>(v1->value);  // NOTE(review): narrowing cast; an extent wider than int may not survive it - confirm
32+
}
33+
if (v2 != nullptr) {
34+
value = static_cast<int>(v2->value);  // NOTE(review): same narrowing concern as the signed case - confirm
35+
}
36+
bool allow_unroll = value >= 0 && value <= max_auto_step_;  // automatic unrolling: constant extent within the limit
37+
if (op->for_type == ForType::Unrolled) {  // loops marked Unrolled bypass the max_auto_step_ limit
38+
CHECK_GE(value, 0)  // but they still need a constant, non-negative extent
39+
<< "Cannot unroll non-constant loop";
40+
allow_unroll = true;
41+
}
42+
43+
if (allow_unroll) {
44+
if (value == 0) return Evaluate::make(0);  // zero-trip loop is replaced by Evaluate(0)
45+
Stmt body = op->body;
46+
Map<Var, Expr> vmap;  // substitution map; its single entry is overwritten on every iteration
47+
Stmt unrolled;  // undefined until the first copy is produced
48+
for (int i = 0; i < value; ++i) {
49+
Var lv(op->loop_var.node_);  // handle to the loop variable, used as the substitution key
50+
vmap.Set(lv,  // bind the loop variable to op->min + i for this copy
51+
schedule::ComputeExpr<Add>(
52+
op->min, make_const(op->loop_var.type(), i)));
53+
Stmt step = Substitute(body, vmap);  // body copy with the loop variable replaced
54+
if (unrolled.defined()) {
55+
unrolled = Block::make(unrolled, step);  // append this copy after the ones built so far
56+
} else {
57+
unrolled = step;  // first copy starts the sequence
58+
}
59+
}
60+
return this->Mutate(unrolled);  // mutate the result again so loops inside the copies are visited too
61+
} else {
62+
return IRMutator::Mutate_(op, stmt);  // not unrolled: defer to the base-class For handling
63+
}
64+
}
65+
66+
private:
67+
int max_auto_step_;  // upper bound on the extent for automatic unrolling (compared with <=)
68+
};
69+
70+
71+
Stmt UnrollLoop(Stmt stmt, int max_auto_step) {  // entry point of the pass: unroll, then convert to SSA
72+
Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);  // apply the unrolling mutator to the whole statement
73+
return ConvertSSA(ret);  // presumably needed because unrolled copies repeat definitions from the body - confirm
74+
}
75+
76+
} // namespace ir
77+
} // namespace tvm

src/schedule/schedule_ops.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
230230
return nest;
231231
}
232232

233+
Stmt Substitute(Stmt s,  // adapter from an IterVar-keyed map to the Var-keyed ir::Substitute
234+
const std::unordered_map<IterVar, Expr>& value_map) {
235+
Map<Var, Expr> temp;  // same replacements, re-keyed by variable
236+
for (const auto& kv : value_map) {
237+
temp.Set(kv.first->var, kv.second);  // key by the IterVar's var field
238+
}
239+
return ir::Substitute(s, temp);  // delegate to the IR pass
240+
}
241+
233242
Stmt MakeLoop(const Stage& s,
234243
const Map<IterVar, Range>& dom_map,
235244
Stmt provide,
@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
244253
auto nest = MakeLoopNest(s, dom_map, 0, false,
245254
bound_state, {}, &value_map);
246255

247-
248256
provide = Substitute(provide, value_map);
249257
if (init.defined()) {
250258
// try to find the location to insert the initialization.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import tvm
2+
3+
def test_unroll_loop():
4+
dtype = 'int64'
5+
n = tvm.Var('n')
6+
Ab = tvm.Buffer((n, ), dtype)
7+
i = tvm.Var('i')
8+
j = tvm.Var('j')
9+
# for i in n to n+1 (constant extent 2): for j in 0 to n-1:
10+
stmt = tvm.make.For(
11+
i, n, 2, 0, 0,
12+
tvm.make.For(j, 0, n, 0, 0,
13+
tvm.make.Store(Ab.data,
14+
tvm.make.Load(dtype, Ab.data, i) + 1,
15+
j + 1)))
16+
stmt = tvm.ir_pass.UnrollLoop(stmt, 8)
17+
print(stmt)
18+
19+
if __name__ == "__main__":
20+
test_unroll_loop()

0 commit comments

Comments
 (0)