Skip to content

Commit de97cd2

Browse files
authored
Changing intel_graphics conv2d to match upstream (apache#48)
1 parent 7d8784b commit de97cd2

1 file changed

Lines changed: 65 additions & 8 deletions

File tree

topi/python/topi/intel_graphics/conv2d.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
2626
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
2727
from ..nn.util import get_pad_tuple
28+
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
2829
from ..nn import pad
2930
from .. import tag
3031
from .. import generic
@@ -162,21 +163,77 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
162163
if F.__name__ == 'tvm.relay.op':
163164
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
164165
new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
165-
166166
# Remove attached compilation target because conv2d_NCHWc needs to create
167167
# a conv2d_nchwc op and target is not one of conv2d's parameters.
168168
if "target" in new_attrs:
169169
del new_attrs["target"]
170170

171-
if F.__name__ == 'nnvm.symbol':
172-
out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
173-
else:
174-
out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
171+
data, kernel = tinfo[0], tinfo[1]
172+
batch_size, in_channel, height, width = get_const_tuple(data.shape)
173+
174+
groups = attrs.get_int("groups")
175+
out_channel = attrs.get_int("channels")
176+
padding = attrs.get_int_tuple("padding")
177+
strides = attrs.get_int_tuple("strides")
178+
dilation = attrs.get_int_tuple("dilation")
179+
out_dtype = attrs["out_dtype"]
180+
181+
layout_name = 'layout' if F == sym else 'data_layout'
182+
layout = attrs[layout_name]
183+
kh, kw = attrs.get_int_tuple("kernel_size")
184+
185+
dtype = data.dtype
186+
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
187+
is_depthwise = groups == in_channel and groups == out_channel
188+
189+
# only optimize for NCHW
190+
if layout != 'NCHW':
191+
return None
192+
if groups != 1 and not is_depthwise:
193+
return None
194+
195+
dispatch_ctx = autotvm.task.DispatchContext.current
196+
target = tvm.target.current_target()
197+
198+
# query schedule and fallback if necessary
199+
workload = autotvm.task.args_to_workload(
200+
[data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
201+
if is_depthwise else \
202+
autotvm.task.args_to_workload(
203+
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
204+
if is_depthwise:
205+
return None
206+
cfg = dispatch_ctx.query(target, workload)
207+
if cfg.is_fallback:
208+
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
209+
210+
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
211+
212+
new_attrs[layout_name] = 'NCHW%dc' % ic_bn
213+
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
214+
215+
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
216+
dtype=data.dtype)
217+
218+
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
219+
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
220+
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
221+
222+
# Store altered operator's config
223+
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
224+
dtype=kernel.dtype)
225+
new_workload = autotvm.task.args_to_workload(
226+
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
227+
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
175228

176-
return out
229+
dispatch_ctx.update(target, new_workload, cfg)
230+
if F == sym:
231+
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
232+
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
177233

178-
@conv2d_NCHWc.register(["intel_graphics"])
179-
def _decl_conv2d(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'):
234+
@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct')
235+
def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
236+
layout, out_layout, out_dtype='float32'):
180237
"""Conv2D operator for Intel Graphics backend.
181238
182239
Parameters

0 commit comments

Comments
 (0)