2626from topi .util import get_const_tuple
2727from topi .vision import ssd , non_max_suppression , get_valid_counts
2828
29+ from common import get_schedule
30+
31+ _get_valid_counts_schedule = {
32+ "generic" : topi .generic .schedule_get_valid_counts ,
33+ "gpu" : topi .cuda .schedule_get_valid_counts ,
34+ }
35+
36+ _nms_schedule = {
37+ "generic" : topi .generic .schedule_nms ,
38+ "gpu" : topi .cuda .schedule_nms ,
39+ }
40+
41+ _multibox_prior_schedule = {
42+ "generic" : topi .generic .schedule_multibox_prior ,
43+ "gpu" : topi .cuda .schedule_multibox_prior ,
44+ }
45+
46+ _multibox_detection_schedule = {
47+ "generic" : topi .generic .schedule_multibox_detection ,
48+ "gpu" : topi .cuda .schedule_multibox_detection ,
49+ }
50+
51+ _roi_align_schedule = {
52+ "generic" : topi .generic .schedule_roi_align ,
53+ "gpu" : topi .cuda .schedule_roi_align ,
54+ }
55+
56+ _roi_pool_schedule = {
57+ "generic" : topi .generic .schedule_roi_pool ,
58+ "gpu" : topi .cuda .schedule_roi_pool ,
59+ }
60+
61+ _proposal_schedule = {
62+ "generic" : topi .generic .schedule_proposal ,
63+ "gpu" : topi .cuda .schedule_proposal ,
64+ }
2965
3066def verify_get_valid_counts (dshape , score_threshold , id_index , score_index ):
3167 dtype = "float32"
@@ -56,7 +92,8 @@ def check_device(device):
5692 with tvm .target .create (device ):
5793 data = tvm .placeholder (dshape , name = "data" , dtype = dtype )
5894 outs = get_valid_counts (data , score_threshold , id_index , score_index )
59- s = topi .generic .schedule_get_valid_counts (outs )
95+ s_func = get_schedule (device , _get_valid_counts_schedule )
96+ s = s_func (outs )
6097
6198 tvm_input_data = tvm .nd .array (np_data , ctx )
6299 tvm_out1 = tvm .nd .array (np .zeros (np_out1 .shape , dtype = "int32" ), ctx )
@@ -68,8 +105,6 @@ def check_device(device):
68105
69106 for device in ['llvm' , 'cuda' , 'opencl' ]:
70107 # Disable gpu test for now
71- if device != "llvm" :
72- continue
73108 check_device (device )
74109
75110
@@ -107,7 +142,8 @@ def check_device(device):
107142 return_indices = False )
108143 indices_out = topi .cuda .non_max_suppression (data , valid_count , - 1 , iou_threshold , force_suppress , top_k ,
109144 coord_start = coord_start , score_index = score_index , id_index = id_index )
110- s = topi .generic .schedule_nms (out )
145+ s_func = get_schedule (device , _nms_schedule )
146+ s = s_func (out )
111147 indices_s = topi .generic .schedule_nms (indices_out )
112148
113149 tvm_data = tvm .nd .array (np_data , ctx )
@@ -198,7 +234,8 @@ def check_device(device):
198234 out = ssd .multibox_prior (data , sizes , ratios , steps , offsets , clip )
199235 else :
200236 out = topi .cuda .ssd .multibox_prior (data , sizes , ratios , steps , offsets , clip )
201- s = topi .generic .schedule_multibox_prior (out )
237+ s_func = get_schedule (device , _multibox_prior_schedule )
238+ s = s_func (out )
202239
203240 tvm_input_data = tvm .nd .array (input_data , ctx )
204241 tvm_out = tvm .nd .array (np .zeros (oshape , dtype = dtype ), ctx )
@@ -244,7 +281,8 @@ def check_device(device):
244281 out = ssd .multibox_detection (cls_prob , loc_preds , anchors )
245282 else :
246283 out = topi .cuda .ssd .multibox_detection (cls_prob , loc_preds , anchors )
247- s = topi .generic .schedule_multibox_detection (out )
284+ s_func = get_schedule (device , _multibox_detection_schedule )
285+ s = s_func (out )
248286
249287 tvm_cls_prob = tvm .nd .array (np_cls_prob .astype (cls_prob .dtype ), ctx )
250288 tvm_loc_preds = tvm .nd .array (np_loc_preds .astype (loc_preds .dtype ), ctx )
@@ -289,7 +327,8 @@ def check_device(device):
289327 b = topi .vision .rcnn .roi_align_nchw (a , rois , pooled_size = pooled_size ,
290328 spatial_scale = spatial_scale ,
291329 sample_ratio = sample_ratio )
292- s = topi .generic .schedule_roi_align (b )
330+ s_func = get_schedule (device , _roi_align_schedule )
331+ s = s_func (b )
293332
294333 tvm_a = tvm .nd .array (a_np , ctx )
295334 tvm_rois = tvm .nd .array (rois_np , ctx )
@@ -338,7 +377,8 @@ def check_device(device):
338377 with tvm .target .create (device ):
339378 b = topi .vision .rcnn .roi_pool_nchw (a , rois , pooled_size = pooled_size ,
340379 spatial_scale = spatial_scale )
341- s = topi .generic .schedule_roi_pool (b )
380+ s_func = get_schedule (device , _roi_pool_schedule )
381+ s = s_func (b )
342382
343383 tvm_a = tvm .nd .array (a_np , ctx )
344384 tvm_rois = tvm .nd .array (rois_np , ctx )
@@ -369,7 +409,8 @@ def check_device(device):
369409 print ("Running on target: %s" % device )
370410 with tvm .target .create (device ):
371411 out = topi .vision .proposal (cls_prob , bbox_pred , im_info , ** attrs )
372- s = topi .generic .schedule_proposal (out )
412+ s_func = get_schedule (device , _proposal_schedule )
413+ s = s_func (out )
373414 f = tvm .build (s , [cls_prob , bbox_pred , im_info , out ], device )
374415 tvm_cls_prob = tvm .nd .array (np_cls_prob , ctx = ctx )
375416 tvm_bbox_pred = tvm .nd .array (np_bbox_pred , ctx = ctx )
0 commit comments