Commit 2b3d2e2

Improve GraphFuse to include five patterns (apache#26)

1 parent 2e9b6b9 · commit 2b3d2e2

9 files changed · 162 additions, 46 deletions

nnvm/docs/top.rst — 10 additions, 3 deletions

```diff
@@ -1,7 +1,8 @@
-NNVM Core Primitives
-====================
+NNVM Core Tensor Operators
+==========================
 
-**Level 1: Basic Ops**
+**Level 1: Basic Operators**
+This level enables fully connected multi-layer perceptron.
 
 .. autosummary::
    :nosignatures:
@@ -12,12 +13,14 @@ NNVM Core Primitives
    nnvm.symbol.sigmoid
    nnvm.symbol.exp
    nnvm.symbol.log
+   nnvm.symbol.sqrt
    nnvm.symbol.elemwise_add
    nnvm.symbol.elemwise_sub
    nnvm.symbol.elemwise_mul
    nnvm.symbol.elemwise_div
    nnvm.symbol.flatten
    nnvm.symbol.concatenate
+   nnvm.symbol.expand_dims
    nnvm.symbol.split
    nnvm.symbol.dropout
    nnvm.symbol.batch_norm
@@ -27,6 +30,8 @@ NNVM Core Primitives
 
 **Level 2: Convolutions**
 
+This level enables typical convnet models.
+
 .. autosummary::
    :nosignatures:
 
@@ -78,12 +83,14 @@ NNVM Core Primitives
 .. autofunction:: nnvm.symbol.sigmoid
 .. autofunction:: nnvm.symbol.exp
 .. autofunction:: nnvm.symbol.log
+.. autofunction:: nnvm.symbol.sqrt
 .. autofunction:: nnvm.symbol.elemwise_add
 .. autofunction:: nnvm.symbol.elemwise_sub
 .. autofunction:: nnvm.symbol.elemwise_mul
 .. autofunction:: nnvm.symbol.elemwise_div
 .. autofunction:: nnvm.symbol.flatten
 .. autofunction:: nnvm.symbol.concatenate
+.. autofunction:: nnvm.symbol.expand_dims
 .. autofunction:: nnvm.symbol.split
 .. autofunction:: nnvm.symbol.dropout
 .. autofunction:: nnvm.symbol.batch_norm
```
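The two operators newly listed at Level 1 compose like any other `nnvm.symbol` primitive. A minimal sketch (not part of this commit; the `axis`/`num_newaxis` attribute names for `expand_dims` are assumptions based on the usual NNVM convention):

```python
import nnvm.symbol as sym

# Hypothetical usage of the two newly documented Level 1 operators.
x = sym.Variable("x")
y = sym.sqrt(x)                                 # elementwise sqrt
z = sym.expand_dims(y, axis=1, num_newaxis=1)   # insert one new axis at position 1
print(z.list_input_names())                     # -> ['x']
```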

nnvm/include/nnvm/compiler/op_attr_types.h — 13 additions, 6 deletions

```diff
@@ -25,16 +25,23 @@ using ::tvm::Tensor;
 using ::tvm::Schedule;
 
 /*! \brief operator pattern used in graph fusion */
-enum OpPatternKind : int {
+enum OpPatternKind {
   // Elementwise operation
   kElemWise = 0,
-  // Broadcast operation
+  // Broadcasting operator, can always map output axis to the input in order.
+  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
+  // Note that the axis need to be in order so transpose is not a bcast operator.
   kBroadcast = 1,
-  // Complex operation, can fuse bcast in input/outputs
+  // Injective operator, can always injectively map output axis to a single input axis.
+  // All injective operator can still be safely fused to injective and reduction.
+  kInjective = 2,
+  // Communicative reduction operator.
+  kCommReduce = 3,
+  // Complex operation, can still fuse elemwise operations into its output.
   // but cannot chain another complex op
-  kComplex = 2,
-  // Extern operation, cannot fuse anything.
-  kExtern = 3
+  kOutEWiseFusable = 4,
+  // Opaque operation, cannot fuse anything.
+  kOpaque = 8
 };
 
 /*! \brief the operator pattern */
```
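The old elemwise / broadcast / complex / extern split becomes a six-value taxonomy, ordered so that a numeric comparison such as `pt <= kInjective` reads as "at most injective". As a rough illustration of where common operators land (my own summary of the enum comments above, not an authoritative table from the commit):

```python
# Illustrative mapping from the new pattern kinds to typical operators;
# the op assignments are examples, not a registry dump.
PATTERN_EXAMPLES = {
    0: ("kElemWise",        "exp, sigmoid, relu"),           # out[i] = f(in[i])
    1: ("kBroadcast",       "broadcast_add"),                # axes map to input in order
    2: ("kInjective",       "reshape, flatten, transpose"),  # one-to-one index mapping
    3: ("kCommReduce",      "sum over an axis"),             # commutative reduction
    4: ("kOutEWiseFusable", "conv2d"),                       # elemwise ops fuse into its output
    8: ("kOpaque",          "softmax, dense"),               # fuses nothing
}
```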

nnvm/python/nnvm/compiler/registry.py — 17 additions, 5 deletions

```diff
@@ -3,12 +3,24 @@
 import tvm
 
 class OpPattern(object):
-    ELEM_WISE = 0
+    """Operator generic patterns
+
+    See Also
+    --------
+    top.tag : Contains explaination of the tag type.
+    """
+    # Elementwise operator
+    ELEMWISE = 0
+    # Broadcast operator
     BROADCAST = 1
-    # Complex means we can fuse elemwise to it
-    COMPLEX = 2
-    # Extern means the op is not fusable
-    EXTERN = 3
+    # Injective mapping
+    INJECTIVE = 2
+    # Comunication
+    COMM_REDUCE = 3
+    # Complex op, can still fuse ewise into it
+    OUT_ELEMWISE_FUSABLE = 4
+    # Not fusable opaque op
+    OPAQUE = 8
 
 _register_compute = tvm.get_global_func("nnvm._register_compute")
 _register_schedule = tvm.get_global_func("nnvm._register_schedule")
```
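A hedged sketch of how an operator picks up one of the new constants, mirroring the registrations in `nnvm/python/nnvm/top/`; `my_op` is hypothetical and would need to be declared as an NNVM operator on the C++ side first:

```python
import topi
from nnvm.compiler import registry as reg
from nnvm.compiler import OpPattern

@reg.register_compute("my_op")          # "my_op" is a placeholder op name
def compute_my_op(attrs, inputs, _):
    """Elementwise compute: out[i] = exp(in[i])."""
    return topi.exp(inputs[0])

# ELEMWISE lets graph fusion merge my_op with elemwise/broadcast neighbors.
reg.register_pattern("my_op", OpPattern.ELEMWISE)
```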

nnvm/python/nnvm/top/nn.py — 13 additions, 5 deletions

```diff
@@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _):
     return topi.nn.relu(inputs[0])
 
 reg.register_schedule("relu", _fschedule_broadcast)
-reg.register_pattern("relu", OpPattern.ELEM_WISE)
+reg.register_pattern("relu", OpPattern.ELEMWISE)
 
+# leaky_relu
+@reg.register_compute("leaky_relu")
+def compute_relu(attrs, inputs, _):
+    """Compute definition of relu"""
+    return topi.nn.leaky_relu(inputs[0])
+
+reg.register_schedule("leaky_relu", _fschedule_broadcast)
+reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
 
 # flatten
 @reg.register_compute("flatten")
@@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _):
     return topi.nn.flatten(inputs[0])
 
 reg.register_schedule("flatten", _fschedule_broadcast)
-reg.register_pattern("flatten", OpPattern.COMPLEX)
+reg.register_pattern("flatten", OpPattern.INJECTIVE)
 
 
 # softmax
@@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target):
     return tvm.create_schedule([x.op for x in outs])
 
 # Mark softmax as extern as we do not fuse it in call cases
-reg.register_pattern("softmax", OpPattern.EXTERN)
+reg.register_pattern("softmax", OpPattern.OPAQUE)
 
 
 # dense
@@ -67,7 +75,7 @@ def schedule_dense(_, outs, target):
     return tvm.create_schedule([x.op for x in outs])
 
 # register extern for now, change me when fusion is enabled.
-reg.register_pattern("dense", OpPattern.EXTERN)
+reg.register_pattern("dense", OpPattern.OPAQUE)
 
 
 # conv
@@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target):
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])
 
-reg.register_pattern("conv2d", OpPattern.COMPLEX)
+reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
```

nnvm/python/nnvm/top/tensor.py — 22 additions, 18 deletions

```diff
@@ -8,13 +8,15 @@
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
-def _schedule_broadcast(_, outs, target):
+def _schedule_injective(_, outs, target):
     """Generic schedule for binary bcast"""
     if target == "cuda":
-        return topi.cuda.schedule_elemwise(outs)
+        return topi.cuda.schedule_injective(outs)
     assert target.startswith("llvm")
     s = tvm.create_schedule([x.op for x in outs])
+    x = outs[0]
     tvm.schedule.AutoInlineInjective(s)
+    s[x].fuse(s[x].op.axis)
     return s
 
 def _compute_binary_scalar(f):
@@ -42,89 +44,91 @@ def _compute(attrs, x, _):
     return _compute
 
 
-_fschedule_broadcast = tvm.convert(_schedule_broadcast)
+_fschedule_injective = tvm.convert(_schedule_injective)
+_fschedule_broadcast = _fschedule_injective
+_fschedule_elemwise = _fschedule_injective
 
 # copy
 reg.register_compute("copy", _compute_unary(topi.identity))
-reg.register_pattern("copy", OpPattern.ELEM_WISE)
+reg.register_pattern("copy", OpPattern.ELEMWISE)
 reg.register_schedule("copy", _fschedule_broadcast)
 
 # exp
 reg.register_compute("exp", _compute_unary(topi.exp))
-reg.register_pattern("exp", OpPattern.ELEM_WISE)
+reg.register_pattern("exp", OpPattern.ELEMWISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
 # sqrt
 reg.register_compute("sqrt", _compute_unary(topi.sqrt))
-reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
+reg.register_pattern("sqrt", OpPattern.ELEMWISE)
 reg.register_schedule("sqrt", _fschedule_broadcast)
 
 # log
 reg.register_compute("log", _compute_unary(topi.log))
-reg.register_pattern("log", OpPattern.ELEM_WISE)
+reg.register_pattern("log", OpPattern.ELEMWISE)
 reg.register_schedule("log", _fschedule_broadcast)
 
 # tanh
 reg.register_compute("tanh", _compute_unary(topi.tanh))
-reg.register_pattern("tanh", OpPattern.ELEM_WISE)
+reg.register_pattern("tanh", OpPattern.ELEMWISE)
 reg.register_schedule("tanh", _fschedule_broadcast)
 
 # negative
 reg.register_compute("negative", _compute_unary(topi.negative))
-reg.register_pattern("negative", OpPattern.ELEM_WISE)
+reg.register_pattern("negative", OpPattern.ELEMWISE)
 reg.register_schedule("negative", _fschedule_broadcast)
 
 # sigmoid
 reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
-reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
+reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)
 
 # add_scalar
 reg.register_compute("__add_scalar__",
                      _compute_binary_scalar(lambda x, y: x + y))
-reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)
 
 # sub_calar
 reg.register_compute("__sub_scalar__",
                      _compute_binary_scalar(lambda x, y: x - y))
-reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
 
 # rsub_scalar
 reg.register_compute("__rsub_scalar__",
                      _compute_binary_scalar(lambda x, y: y - x))
-reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
 
 # mul_scalar
 reg.register_compute("__mul_scalar__",
                      _compute_binary_scalar(lambda x, y: x * y))
-reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
 
 # div_scalar
 reg.register_compute("__div_scalar__",
                      _compute_binary_scalar(lambda x, y: x / y))
-reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__div_scalar__", _fschedule_broadcast)
 
 # rdiv_scalar
 reg.register_compute("__rdiv_scalar__",
                      _compute_binary_scalar(lambda x, y: y / x))
-reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
 
 # pow_scalar
 reg.register_compute("__pow_scalar__",
                      _compute_binary_scalar(tvm.power))
-reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
 
 # rpow_scalar
 reg.register_compute("__rpow_scalar__",
                      _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
-reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
+reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
 
 # elemwise_add
```
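The new generic injective schedule inlines every injective stage and collapses the output loop nest into one flat loop. A standalone TVM sketch of the same recipe (illustrative only; note the commit passes the whole axis list to `fuse`, i.e. `s[x].fuse(s[x].op.axis)`, matching the TVM API of the time, while later TVM spells it `fuse(*axes)`):

```python
import tvm
import topi

# Two injective stages: y = exp(x), then z = 2 * y.
x = tvm.placeholder((4, 8), name="x")
y = topi.exp(x)
z = tvm.compute(x.shape, lambda i, j: y[i, j] * 2.0, name="z")

s = tvm.create_schedule(z.op)
tvm.schedule.AutoInlineInjective(s)   # inline exp into z's loop nest
fused = s[z].fuse(*s[z].op.axis)      # one flat loop over the output
f = tvm.build(s, [x, z], target="llvm")
```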

nnvm/python/nnvm/top/transform.py — 1 addition, 1 deletion

```diff
@@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info):
     oshape = out_info[0].shape
     x = inputs[0]
     return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
-reg.register_pattern("reshape", OpPattern.COMPLEX)
+reg.register_pattern("reshape", OpPattern.INJECTIVE)
 reg.register_schedule("reshape", _fschedule_broadcast)
```

nnvm/src/compiler/graph_fuse.cc — 24 additions, 7 deletions

```diff
@@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     ref_count[e.node_id] += 2;
   }
   // Pattern for the subgraph
-  std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
+  std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
   // Whether node can be fused to parent.
   std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
   // Master node id of fusion segment.
@@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     if (inode.source->is_variable()) {
       fuse_vec[nid] = FuseRule::kRealize; continue;
     }
-    TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);
+    TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);
 
     if (pt <= kBroadcast) {
+      // Try to check if we can fuse to the master.
      int chosen_master = -1;
      bool ewise = inode.source->num_outputs() == 1;
      for (const auto& e : inode.inputs) {
        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
          TOpPattern ipt = pattern_vec[e.node_id];
          if (ipt != kElemWise) ewise = false;
-          if (ipt <= kBroadcast) {
+          if (ipt <= kInjective) {
            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
-          } else if (ipt == kComplex && chosen_master == -1 &&
-                     shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
+          } else if (ipt == kOutEWiseFusable &&
+                     chosen_master == -1 &&
+                     shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
            chosen_master = master_vec[e.node_id];
            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
          } else {
@@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
        }
      }
      master_vec[nid] = chosen_master;
      if (chosen_master != -1) {
-        pt = kComplex;
+        pt = kOutEWiseFusable;
      } else {
        pt = ewise ? kElemWise : kBroadcast;
      }
+    } else if (pt == kInjective || pt == kCommReduce) {
+      // fuse to the comm reduce or injective
+      for (const auto& e : inode.inputs) {
+        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
+          TOpPattern ipt = pattern_vec[e.node_id];
+          if (ipt <= kInjective) {
+            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+          } else {
+            fuse_vec[e.node_id] = FuseRule::kRealize;
+          }
+        }
+      }
+      if (pt == kCommReduce) {
+        master_vec[nid] = nid;
+      }
    } else {
+      // realize
      master_vec[nid] = nid;
      for (const auto& e : inode.inputs) {
        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
@@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     }
   }
 
-
   // point to the group root id of each node
   std::vector<int> group_vec(idx.num_nodes(), -1);
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
```
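In plain terms, the partition pass now has three regimes instead of two. A paraphrase of the decision rule above as Python pseudocode (values follow the new enum; the single-master bookkeeping is simplified away):

```python
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_ELEMWISE_FUSABLE, OPAQUE = 0, 1, 2, 3, 4, 8

def fuse_rule(node_pattern, input_pattern, output_shape_matches):
    """What may a node's input do, per GraphFusePartition above?"""
    if node_pattern <= BROADCAST:
        if input_pattern <= INJECTIVE:
            return "FuseToMaster"
        if input_pattern == OUT_ELEMWISE_FUSABLE and output_shape_matches:
            return "FuseToMaster"   # at most one such master per segment
        return "Realize"
    if node_pattern in (INJECTIVE, COMM_REDUCE):
        # injective ops and reductions absorb inputs up to injective
        return "FuseToMaster" if input_pattern <= INJECTIVE else "Realize"
    return "Realize"                # OUT_ELEMWISE_FUSABLE / OPAQUE nodes realize inputs
```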

nnvm/src/compiler/layout_transform.cc — 1 addition, 1 deletion

```diff
@@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
 
   // use op pattern to decide whether an op is map
  auto is_map_op = [&](size_t nid) {
-    TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
+    TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
    bool is_map = (pt <= kBroadcast);
    if (pt == kBroadcast) {
      for (const auto& e : idx[nid].inputs) {
```
