Skip to content

Commit 15584c2

Browse files
author
Mikael Sevenier
committed
Merge branches 'master' and 'master' of https://github.com/apache/incubator-tvm
2 parents b6b920d + 2621554 commit 15584c2

17 files changed

Lines changed: 372 additions & 98 deletions

File tree

docs/deploy/arm_compute_lib.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ Operator support
232232
+----------------------+-------------------------------------------------------------------------+
233233
| reshape | fp32, uint8 |
234234
+----------------------+-------------------------------------------------------------------------+
235+
| maximum | fp32 |
236+
+----------------------+-------------------------------------------------------------------------+
235237

236238
.. note::
237239
A composite operator is a series of operators that map to a single Arm Compute Library operator. You can view this

python/tvm/relay/op/contrib/arm_compute_lib.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,11 @@ def global_avg_pool2d(attrs, args):
337337
if attrs.layout != "NHWC":
338338
return False
339339
return True
340+
341+
342+
@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
def maximum(attrs, args):
    """Check if the external ACL codegen for maximum should be used.

    Parameters
    ----------
    attrs : operator attributes (not consulted by this check).
    args : the two operands of the ``maximum`` call; each must expose
        ``checked_type.dtype``.

    Returns
    -------
    bool
        True only when both operands are float32.
    """
    type_a = args[0].checked_type
    # Inspect the second operand as well: reading args[0] twice would let a
    # non-float32 right-hand side be offloaded unchecked.
    type_b = args[1].checked_type
    return (type_a.dtype == "float32") and (type_b.dtype == "float32")

python/tvm/relay/testing/yolo_detection.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -196,41 +196,91 @@ def do_nms_sort(dets, classes, thresh):
196196
dets[j]["prob"][k] = 0
197197

198198

199+
def get_detections(im, det, thresh, names, classes):
    """Collect the labels, clipped box and colour for a single detection.

    Nothing is drawn here; the function only computes the data that callers
    use to draw or print a detection.

    Parameters
    ----------
    im : image whose ``shape`` unpacks into three values; the last two are
        used as height and width (presumably channel-first layout, confirm
        against callers).
    det : mapping with a "prob" sequence indexed by class and a "bbox"
        object exposing ``x``, ``y``, ``w`` and ``h``.
    thresh : a class is reported when its probability is strictly above this.
    names : class names indexed by class id.
    classes : number of classes to scan.

    Returns
    -------
    (valid, detection)
        ``(False, None)`` when no class exceeds ``thresh``; otherwise
        ``(True, dict)`` with keys "category", "labelstr", "left", "top",
        "right", "bot", "width" and "rgb".
    """
    labelstr = []
    category = -1
    for j in range(classes):
        prob = det["prob"][j]
        if prob > thresh:
            # The first class above the threshold decides the box colour.
            if category == -1:
                category = j
            labelstr.append(names[j] + " " + str(round(prob, 4)))

    if category == -1:
        return False, None

    _, imh, imw = im.shape
    width = int(imh * 0.006)
    offset = category * 123457 % classes
    red = _get_color(2, offset, classes)
    green = _get_color(1, offset, classes)
    blue = _get_color(0, offset, classes)

    # Box centre/size are scaled by the image size to get pixel edges, then
    # clamped so the box stays inside the image.
    b = det["bbox"]
    left = max(int((b.x - b.w / 2.0) * imw), 0)
    right = min(int((b.x + b.w / 2.0) * imw), imw - 1)
    top = max(int((b.y - b.h / 2.0) * imh), 0)
    bot = min(int((b.y + b.h / 2.0) * imh), imh - 1)

    detection = {
        "category": category,
        "labelstr": labelstr,
        "left": left,
        "top": top,
        "right": right,
        "bot": bot,
        "width": width,
        "rgb": [red, green, blue],
    }
    return True, detection
247+
248+
199249
def draw_detections(font_path, im, dets, thresh, names, classes):
    "Draw a box and a text label on the image for every valid detection"
    for candidate in dets:
        found, info = get_detections(im, candidate, thresh, names, classes)
        if not found:
            continue

        colour = info["rgb"]
        # Build the label first, then draw the box, then place the label.
        label = _get_label(font_path, "".join(info["labelstr"]), colour)
        box_top = info["top"]
        box_left = info["left"]
        line_width = info["width"]
        _draw_box_width(
            im,
            box_left,
            box_top,
            info["right"],
            info["bot"],
            line_width,
            colour[0],
            colour[1],
            colour[2],
        )
        _draw_label(im, box_top + line_width, box_left, label, colour)
268+
269+
270+
def show_detections(im, dets, thresh, names, classes):
    """Print the labels and bounding box of every valid detection.

    Each entry of ``dets`` is passed to ``get_detections``; for those it
    reports as valid, one line is printed with the label list and the
    left/right/top/bottom box edges.
    """
    for det in dets:
        valid, detection = get_detections(im, det, thresh, names, classes)
        if valid:
            # Argument order must follow the placeholders in the message:
            # left, right, top, bottom.
            print(
                "class:{} left:{} right:{} top:{} bottom:{}".format(
                    detection["labelstr"],
                    detection["left"],
                    detection["right"],
                    detection["top"],
                    detection["bot"],
                )
            )
234284

235285

236286
def _get_pixel(im, x, y, c):

rust/tvm/examples/resnet/src/build_resnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def build(target_dir):
7474
block = get_model("resnet18_v1", pretrained=True)
7575
net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
7676
# we want a probability so add a softmax operator
77+
func = net["main"]
7778
net = relay.Function(
78-
net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
79+
func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs
7980
)
8081
else:
8182
# use random weights from relay.testing

src/auto_scheduler/search_policy/sketch_policy.cc

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "sketch_policy.h"
2828

2929
#include <tvm/runtime/registry.h>
30+
#include <tvm/support/parallel_for.h>
3031

3132
#include <algorithm>
3233
#include <iomanip>
@@ -334,28 +335,44 @@ Array<State> SketchPolicyNode::GenerateSketches() {
334335
Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches, int out_size) {
335336
int fail_ct = 0;
336337
Array<State> out_states;
338+
std::vector<std::mt19937> rand_gens;
339+
rand_gens.reserve(out_size);
340+
for (int i = 0; i < out_size; i++) {
341+
rand_gens.push_back(std::mt19937(rand_gen()));
342+
}
337343
auto tic_begin = std::chrono::high_resolution_clock::now();
338344

339345
while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
340-
// Random choose a starting sketch
341-
// TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
342-
// different potential on generating state with better performance
343-
State tmp_s = sketches[(rand_gen)() % sketches.size()];
344-
345-
// Derivation rule based enumeration
346-
bool valid = true;
347-
for (const auto& rule : init_rules) {
348-
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
349-
valid = false;
350-
break;
346+
std::vector<State> temp_states(out_size);
347+
348+
support::parallel_for(0, out_size - out_states.size(),
349+
[this, &temp_states, &sketches, &rand_gens](int index) {
350+
// Random choose a starting sketch
351+
// TODO(jcf94, merrymercy): Maybe choose sketches in different
352+
// possibility for they may have different potential on generating state
353+
// with better performance
354+
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
355+
// Derivation rule based enumeration
356+
bool valid = true;
357+
for (const auto& rule : init_rules) {
358+
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
359+
PopulationGenerationRule::ResultKind::kInvalid) {
360+
valid = false;
361+
break;
362+
}
363+
}
364+
if (valid) {
365+
temp_states[index] = std::move(tmp_s);
366+
}
367+
});
368+
369+
for (int i = 0; i < out_size; i++) {
370+
if (temp_states[i].defined()) {
371+
out_states.push_back(std::move(temp_states[i]));
372+
} else {
373+
fail_ct++;
351374
}
352375
}
353-
354-
if (valid) {
355-
out_states.push_back(std::move(tmp_s));
356-
} else {
357-
fail_ct++;
358-
}
359376
}
360377

361378
double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
@@ -461,7 +478,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
461478

462479
if (dis(rand_gen) < mutation_prob) {
463480
const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)];
464-
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
481+
if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) {
465482
pnext->push_back(std::move(tmp_s));
466483
mutation_success_ct++;
467484
} else {

0 commit comments

Comments
 (0)