25 | 25 | from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity |
26 | 26 | from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout |
27 | 27 | from ..nn.util import get_pad_tuple |
| 28 | +from ..nn.depthwise_conv2d import depthwise_conv2d_nchw |
28 | 29 | from ..nn import pad |
29 | 30 | from .. import tag |
30 | 31 | from .. import generic |
@@ -162,21 +163,77 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): |
162 | 163 |     if F.__name__ == 'tvm.relay.op':
163 | 164 |         # Derive channels for frontends (e.g. ONNX) that lack the "channels" field.
164 | 165 |         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
165 |     | -
166 | 166 |     # Remove attached compilation target because conv2d_NCHWc needs to create
167 | 167 |     # a conv2d_nchwc op and target is not one of conv2d's parameters.
168 | 168 |     if "target" in new_attrs:
169 | 169 |         del new_attrs["target"]
170 | 170 |
171 |     | -    if F.__name__ == 'nnvm.symbol':
172 |     | -        out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
173 |     | -    else:
174 |     | -        out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
    | 171 | +    data, kernel = tinfo[0], tinfo[1]
    | 172 | +    batch_size, in_channel, height, width = get_const_tuple(data.shape)
    | 173 | +
    | 174 | +    groups = attrs.get_int("groups")
    | 175 | +    out_channel = attrs.get_int("channels")
    | 176 | +    padding = attrs.get_int_tuple("padding")
    | 177 | +    strides = attrs.get_int_tuple("strides")
    | 178 | +    dilation = attrs.get_int_tuple("dilation")
    | 179 | +    out_dtype = attrs["out_dtype"]
    | 180 | +
    | 181 | +    layout_name = 'layout' if F == sym else 'data_layout'
    | 182 | +    layout = attrs[layout_name]
    | 183 | +    kh, kw = attrs.get_int_tuple("kernel_size")
    | 184 | +
    | 185 | +    dtype = data.dtype
    | 186 | +    out_dtype = dtype if out_dtype in ("same", "") else out_dtype
    | 187 | +    is_depthwise = groups == in_channel and groups == out_channel
    | 188 | +
    | 189 | +    # only optimize for NCHW; leave other layouts and grouped (non-depthwise) conv2d untouched
    | 190 | +    if layout != 'NCHW':
    | 191 | +        return None
    | 192 | +    if groups != 1 and not is_depthwise:
    | 193 | +        return None
    | 194 | +
    | 195 | +    dispatch_ctx = autotvm.task.DispatchContext.current
    | 196 | +    target = tvm.target.current_target()
    | 197 | +
    | 198 | +    # query schedule and fall back to the default config if necessary
    | 199 | +    workload = autotvm.task.args_to_workload(
    | 200 | +        [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
    | 201 | +        if is_depthwise else \
    | 202 | +        autotvm.task.args_to_workload(
    | 203 | +            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
    | 204 | +    if is_depthwise:
    | 205 | +        return None  # depthwise conv2d is not rewritten to NCHWc on this target yet
    | 206 | +    cfg = dispatch_ctx.query(target, workload)
    | 207 | +    if cfg.is_fallback:
    | 208 | +        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
    | 209 | +
    | 210 | +    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    | 211 | +
    | 212 | +    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
    | 213 | +    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
    | 214 | +
    | 215 | +    new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
    | 216 | +                               dtype=data.dtype)
    | 217 | +
    | 218 | +    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
    | 219 | +    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
    | 220 | +    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
    | 221 | +
    | 222 | +    # Store altered operator's config
    | 223 | +    new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
    | 224 | +                                 dtype=kernel.dtype)
    | 225 | +    new_workload = autotvm.task.args_to_workload(
    | 226 | +        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
    | 227 | +         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
175 | 228 |
176 |     | -    return out
    | 229 | +    dispatch_ctx.update(target, new_workload, cfg)
    | 230 | +    if F == sym:
    | 231 | +        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
    | 232 | +    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
177 | 233 |
178 | | -@conv2d_NCHWc.register(["intel_graphics"]) |
179 | | -def _decl_conv2d(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'): |
| 234 | +@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct') |
| 235 | +def _decl_conv2d(cfg, data, kernel, strides, padding, dilation, |
    | 236 | +                  layout, out_layout, out_dtype='float32'):
180 | 237 |     """Conv2D operator for Intel Graphics backend.
181 | 238 |
182 | 239 |     Parameters
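For readers skimming the diff: the early-return guards at the top of the new `_alter_conv2d_layout` body decide which conv2d workloads get the NCHWc rewrite at all. The snippet below is a pure-Python paraphrase of those guards, added here only for illustration; the helper name `handled_by_alter_layout` is made up and does not exist in TVM.

```python
def handled_by_alter_layout(layout, groups, in_channel, out_channel):
    """Hypothetical helper mirroring the guards in _alter_conv2d_layout above."""
    is_depthwise = groups == in_channel and groups == out_channel
    if layout != 'NCHW':
        return False   # only NCHW data layouts are altered
    if groups != 1 and not is_depthwise:
        return False   # grouped (non-depthwise) convolutions are left unchanged
    if is_depthwise:
        return False   # depthwise bails out before the rewrite in this patch
    return True

print(handled_by_alter_layout('NCHW', 1, 32, 64))   # True:  ordinary conv2d is rewritten
print(handled_by_alter_layout('NHWC', 1, 32, 64))   # False: layout is not NCHW
print(handled_by_alter_layout('NCHW', 32, 32, 32))  # False: depthwise
print(handled_by_alter_layout('NCHW', 2, 32, 64))   # False: grouped conv
```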
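The channel block sizes that drive the new layouts are the innermost split factors of the `tile_ic`/`tile_oc` knobs read from the queried config. A minimal sketch of how those factors become the layout strings, assuming fallback-style `SplitEntity` values with an inner block of 16 (the real factors are whatever the tuner or `_get_default_config` picked):

```python
from tvm.autotvm.task.space import SplitEntity

# Assumed split factors for illustration: channels tiled as (outer, inner=16).
tile_ic = SplitEntity([-1, 16])
tile_oc = SplitEntity([-1, 16])

# Only the innermost factor matters for the layout string, as in the diff above.
ic_bn, oc_bn = tile_ic.size[-1], tile_oc.size[-1]
print('NCHW%dc' % ic_bn)               # data / output layout -> 'NCHW16c'
print('OIHW%di%do' % (ic_bn, oc_bn))   # kernel layout        -> 'OIHW16i16o'
```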
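The `tvm.placeholder` shapes registered for `new_workload` describe the blocked tensors that `conv2d_NCHWc` expects. As a sanity check of what `NCHW16c` and `OIHW16i16o` mean in terms of plain arrays, here is a small NumPy sketch; it is illustrative only, and the packing helpers are hypothetical, not part of the patch or of TVM.

```python
import numpy as np

def pack_nchw_to_nchwc(data, ic_bn):
    """Hypothetical helper: NCHW -> NCHW[ic_bn]c, the blocked data layout."""
    n, c, h, w = data.shape
    assert c % ic_bn == 0, "channel count must be divisible by the inner block"
    # (N, C, H, W) -> (N, C//ic_bn, ic_bn, H, W) -> (N, C//ic_bn, H, W, ic_bn)
    return data.reshape(n, c // ic_bn, ic_bn, h, w).transpose(0, 1, 3, 4, 2)

def pack_oihw_to_oihw_io(kernel, ic_bn, oc_bn):
    """Hypothetical helper: OIHW -> OIHW[ic_bn]i[oc_bn]o, the blocked kernel layout."""
    o, i, kh, kw = kernel.shape
    assert o % oc_bn == 0 and i % ic_bn == 0
    # (O, I, H, W) -> (O//oc_bn, oc_bn, I//ic_bn, ic_bn, H, W)
    #             -> (O//oc_bn, I//ic_bn, H, W, ic_bn, oc_bn)
    return (kernel.reshape(o // oc_bn, oc_bn, i // ic_bn, ic_bn, kh, kw)
                  .transpose(0, 2, 4, 5, 3, 1))

# Example: a 1x32x28x28 input and 64 3x3 filters, blocked by ic_bn = oc_bn = 16.
data = np.zeros((1, 32, 28, 28), dtype='float32')
kernel = np.zeros((64, 32, 3, 3), dtype='float32')
print(pack_nchw_to_nchwc(data, 16).shape)           # (1, 2, 28, 28, 16)
print(pack_oihw_to_oihw_io(kernel, 16, 16).shape)   # (4, 2, 3, 3, 16, 16)
```

The printed shapes match the `new_data` and `new_kernel` placeholders used to build `new_workload` in the diff above.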