|
8 | 8 | from ..compiler import registry as reg |
9 | 9 | from ..compiler import OpPattern |
10 | 10 |
|
11 | | -def _schedule_broadcast(_, outs, target): |
| 11 | +def _schedule_injective(_, outs, target): |
12 | 12 | """Generic schedule for injective ops (shared by elemwise and broadcast)""" |
13 | 13 | if target == "cuda": |
14 | | - return topi.cuda.schedule_elemwise(outs) |
| 14 | + return topi.cuda.schedule_injective(outs) |
15 | 15 | assert target.startswith("llvm") |
16 | 16 | s = tvm.create_schedule([x.op for x in outs]) |
| 17 | + x = outs[0] |
17 | 18 | tvm.schedule.AutoInlineInjective(s) |
| 19 | + s[x].fuse(s[x].op.axis) |
18 | 20 | return s |
19 | 21 |
|
20 | 22 | def _compute_binary_scalar(f): |
@@ -42,89 +44,91 @@ def _compute(attrs, x, _): |
42 | 44 | return _compute |
43 | 45 |
|
44 | 46 |
|
45 | | -_fschedule_broadcast = tvm.convert(_schedule_broadcast) |
| 47 | +_fschedule_injective = tvm.convert(_schedule_injective) |
| 48 | +_fschedule_broadcast = _fschedule_injective |
| 49 | +_fschedule_elemwise = _fschedule_injective |
46 | 50 |
|
47 | 51 | # copy |
48 | 52 | reg.register_compute("copy", _compute_unary(topi.identity)) |
49 | | -reg.register_pattern("copy", OpPattern.ELEM_WISE) |
| 53 | +reg.register_pattern("copy", OpPattern.ELEMWISE) |
50 | 54 | reg.register_schedule("copy", _fschedule_broadcast) |
51 | 55 |
|
52 | 56 | # exp |
53 | 57 | reg.register_compute("exp", _compute_unary(topi.exp)) |
54 | | -reg.register_pattern("exp", OpPattern.ELEM_WISE) |
| 58 | +reg.register_pattern("exp", OpPattern.ELEMWISE) |
55 | 59 | reg.register_schedule("exp", _fschedule_broadcast) |
56 | 60 |
|
57 | 61 | # sqrt |
58 | 62 | reg.register_compute("sqrt", _compute_unary(topi.sqrt)) |
59 | | -reg.register_pattern("sqrt", OpPattern.ELEM_WISE) |
| 63 | +reg.register_pattern("sqrt", OpPattern.ELEMWISE) |
60 | 64 | reg.register_schedule("sqrt", _fschedule_broadcast) |
61 | 65 |
|
62 | 66 | # log |
63 | 67 | reg.register_compute("log", _compute_unary(topi.log)) |
64 | | -reg.register_pattern("log", OpPattern.ELEM_WISE) |
| 68 | +reg.register_pattern("log", OpPattern.ELEMWISE) |
65 | 69 | reg.register_schedule("log", _fschedule_broadcast) |
66 | 70 |
|
67 | 71 | # tanh |
68 | 72 | reg.register_compute("tanh", _compute_unary(topi.tanh)) |
69 | | -reg.register_pattern("tanh", OpPattern.ELEM_WISE) |
| 73 | +reg.register_pattern("tanh", OpPattern.ELEMWISE) |
70 | 74 | reg.register_schedule("tanh", _fschedule_broadcast) |
71 | 75 |
|
72 | 76 | # negative |
73 | 77 | reg.register_compute("negative", _compute_unary(topi.negative)) |
74 | | -reg.register_pattern("negative", OpPattern.ELEM_WISE) |
| 78 | +reg.register_pattern("negative", OpPattern.ELEMWISE) |
75 | 79 | reg.register_schedule("negative", _fschedule_broadcast) |
76 | 80 |
|
77 | 81 | # sigmoid |
78 | 82 | reg.register_compute("sigmoid", _compute_unary(topi.sigmoid)) |
79 | | -reg.register_pattern("sigmoid", OpPattern.ELEM_WISE) |
| 83 | +reg.register_pattern("sigmoid", OpPattern.ELEMWISE) |
80 | 84 | reg.register_schedule("sigmoid", _fschedule_broadcast) |
81 | 85 |
|
82 | 86 | # add_scalar |
83 | 87 | reg.register_compute("__add_scalar__", |
84 | 88 | _compute_binary_scalar(lambda x, y: x + y)) |
85 | | -reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE) |
| 89 | +reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE) |
86 | 90 | reg.register_schedule("__add_scalar__", _fschedule_broadcast) |
87 | 91 |
|
88 | 92 | # sub_scalar |
89 | 93 | reg.register_compute("__sub_scalar__", |
90 | 94 | _compute_binary_scalar(lambda x, y: x - y)) |
91 | | -reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE) |
| 95 | +reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE) |
92 | 96 | reg.register_schedule("__sub_scalar__", _fschedule_broadcast) |
93 | 97 |
|
94 | 98 | # rsub_scalar |
95 | 99 | reg.register_compute("__rsub_scalar__", |
96 | 100 | _compute_binary_scalar(lambda x, y: y - x)) |
97 | | -reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE) |
| 101 | +reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE) |
98 | 102 | reg.register_schedule("__rsub_scalar__", _fschedule_broadcast) |
99 | 103 |
|
100 | 104 | # mul_scalar |
101 | 105 | reg.register_compute("__mul_scalar__", |
102 | 106 | _compute_binary_scalar(lambda x, y: x * y)) |
103 | | -reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE) |
| 107 | +reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE) |
104 | 108 | reg.register_schedule("__mul_scalar__", _fschedule_broadcast) |
105 | 109 |
|
106 | 110 | # div_scalar |
107 | 111 | reg.register_compute("__div_scalar__", |
108 | 112 | _compute_binary_scalar(lambda x, y: x / y)) |
109 | | -reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE) |
| 113 | +reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE) |
110 | 114 | reg.register_schedule("__div_scalar__", _fschedule_broadcast) |
111 | 115 |
|
112 | 116 | # rdiv_scalar |
113 | 117 | reg.register_compute("__rdiv_scalar__", |
114 | 118 | _compute_binary_scalar(lambda x, y: y / x)) |
115 | | -reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE) |
| 119 | +reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE) |
116 | 120 | reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast) |
117 | 121 |
|
118 | 122 | # pow_scalar |
119 | 123 | reg.register_compute("__pow_scalar__", |
120 | 124 | _compute_binary_scalar(tvm.power)) |
121 | | -reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE) |
| 125 | +reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE) |
122 | 126 | reg.register_schedule("__pow_scalar__", _fschedule_broadcast) |
123 | 127 |
|
124 | 128 | # rpow_scalar |
125 | 129 | reg.register_compute("__rpow_scalar__", |
126 | 130 | _compute_binary_scalar(lambda x, y: tvm.power(y, x))) |
127 | | -reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE) |
| 131 | +reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE) |
128 | 132 | reg.register_schedule("__rpow_scalar__", _fschedule_broadcast) |
129 | 133 |
|
130 | 134 | # elemwise_add |
|
0 commit comments