Skip to content

Commit c1fee31

Browse files
tqchentmoreau89
authored andcommitted
[VTA][SIM] Allow debug mode in simulator to skip execution (apache#6)
1 parent f3aeda7 commit c1fee31

4 files changed

Lines changed: 54 additions & 88 deletions

File tree

vta/python/vta/testing/simulator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@ def clear_stats():
2727
if f:
2828
f()
2929

30+
# debug flag to skip execution.
31+
DEBUG_SKIP_EXEC = 1
32+
33+
def debug_mode(flag):
34+
"""Set debug mode
35+
36+
Paramaters
37+
----------
38+
flag : int
39+
The debug flag, 0 means clear all flags.
40+
"""
41+
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
42+
3043

3144
def stats():
3245
"""Clear profiler statistics

vta/python/vta/top/arm_conv2d.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,87 +5,6 @@
55
from topi.nn import conv2d, conv2d_alter_layout
66
from topi import generic
77

8-
_WORKLOADS = [
9-
# resnet 18
10-
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
11-
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
12-
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
13-
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
14-
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
15-
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
16-
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
17-
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
18-
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
19-
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
20-
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
21-
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
22-
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
23-
24-
# mobilenet float32
25-
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
26-
Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
27-
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
28-
Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
29-
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
30-
Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
31-
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
32-
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
33-
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
34-
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
35-
36-
# mobilenet int8
37-
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
38-
Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
39-
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
40-
Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
41-
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
42-
Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
43-
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
44-
Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
45-
Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
46-
Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
47-
]
48-
49-
_SCHEDULES = [
50-
# float32 imagenet
51-
SpatialPack(1, 8, 4, 1, 4, True),
52-
SpatialPack(1, 8, 4, 1, 4, True),
53-
SpatialPack(1, 7, 4, 2, 4, True),
54-
SpatialPack(1, 4, 8, 4, 1, True),
55-
SpatialPack(1, 4, 4, 1, 16, False),
56-
SpatialPack(1, 4, 8, 4, 8, False),
57-
SpatialPack(1, 7, 4, 3, 8, True),
58-
SpatialPack(1, 2, 8, 1, 8, True),
59-
SpatialPack(2, 1, 16, 1, 4, True),
60-
SpatialPack(1, 7, 4, 1, 1, True),
61-
Im2ColPack(7, 4, 1, 16, True),
62-
Im2ColPack(7, 4, 1, 8, False),
63-
Im2ColPack(7, 4, 1, 16, False),
64-
65-
# float32 mobilenet
66-
SpatialPack(2, 2, 4, 28, 1, True),
67-
SpatialPack(1, 4, 8, 14, 1, False),
68-
SpatialPack(1, 2, 16, 8, 1, True),
69-
SpatialPack(1, 4, 8, 8, 8, True),
70-
SpatialPack(2, 2, 8, 1, 1, False),
71-
SpatialPack(1, 4, 8, 4, 8, False),
72-
SpatialPack(2, 2, 8, 1, 4, False),
73-
SpatialPack(2, 2, 8, 1, 8, False),
74-
Im2ColPack(7, 4, 1, 16, False),
75-
Im2ColPack(7, 4, 1, 4, True),
76-
77-
# int8 mobilenet
78-
SpatialPack(2, 2, 4, 28, 1, True),
79-
SpatialPack(1, 4, 8, 14, 1, False),
80-
SpatialPack(1, 2, 16, 8, 1, True),
81-
SpatialPack(1, 4, 8, 8, 8, True),
82-
SpatialPack(2, 2, 8, 1, 1, False),
83-
SpatialPack(1, 4, 8, 4, 8, False),
84-
SpatialPack(2, 2, 8, 1, 4, False),
85-
SpatialPack(2, 2, 8, 1, 8, False),
86-
Im2ColPack(7, 4, 1, 16, False),
87-
Im2ColPack(7, 4, 1, 4, True),
88-
]
898

909
@conv2d.register(["vtacpu", "vta"])
9110
def compute(*args, **kwargs):

vta/src/sim/sim_driver.cc

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
namespace vta {
1717
namespace sim {
1818

19+
/*! \brief debug flag for skipping computation */
20+
enum DebugFlagMask {
21+
kSkipExec = 1
22+
};
23+
1924
/*!
2025
* \brief Helper class to pack and unpack bits
2126
* Applies truncation when pack to low level bits.
@@ -234,8 +239,12 @@ class SRAM {
234239
return &(data_[index]);
235240
}
236241
// Execute the load instruction on this SRAM
237-
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
242+
void Load(const VTAMemInsn* op,
243+
DRAM* dram,
244+
uint64_t* load_counter,
245+
bool skip_exec) {
238246
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
247+
if (skip_exec) return;
239248
DType* sram_ptr = data_ + op->sram_base;
240249
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
241250
op->dram_base * kElemBytes));
@@ -306,6 +315,8 @@ class Profiler {
306315
uint64_t gemm_counter{0};
307316
/*! \brief instr counter for ALU ops */
308317
uint64_t alu_counter{0};
318+
/*! \brief set debug mode */
319+
int64_t debug_flag{0};
309320
/*! \brief clear the profiler */
310321
void Clear() {
311322
inp_load_nbytes = 0;
@@ -316,6 +327,10 @@ class Profiler {
316327
gemm_counter = 0;
317328
alu_counter = 0;
318329
}
330+
/*! \return Whether we should skip execution. */
331+
bool SkipExec() const {
332+
return (debug_flag & DebugFlagMask::kSkipExec) != 0;
333+
}
319334

320335
std::string AsJSON() {
321336
std::ostringstream os;
@@ -379,13 +394,15 @@ class Device {
379394
void RunLoad(const VTAMemInsn* op) {
380395
if (op->x_size == 0) return;
381396
if (op->memory_type == VTA_MEM_ID_INP) {
382-
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
397+
inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec());
383398
} else if (op->memory_type == VTA_MEM_ID_WGT) {
384-
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
399+
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec());
385400
} else if (op->memory_type == VTA_MEM_ID_ACC) {
386-
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
401+
acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec());
387402
} else if (op->memory_type == VTA_MEM_ID_UOP) {
388-
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
403+
// always load in uop, since uop is stateful
404+
// subsequent non-debug mode exec can depend on it.
405+
uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false);
389406
} else {
390407
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
391408
}
@@ -397,7 +414,9 @@ class Device {
397414
op->memory_type == VTA_MEM_ID_UOP) {
398415
prof_->out_store_nbytes += (
399416
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
400-
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
417+
if (!prof_->SkipExec()) {
418+
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
419+
}
401420
} else {
402421
LOG(FATAL) << "Store do not support memory_type="
403422
<< op->memory_type;
@@ -407,6 +426,7 @@ class Device {
407426
void RunGEMM(const VTAGemInsn* op) {
408427
if (!op->reset_reg) {
409428
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
429+
if (prof_->SkipExec()) return;
410430
for (uint32_t y = 0; y < op->iter_out; ++y) {
411431
for (uint32_t x = 0; x < op->iter_in; ++x) {
412432
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
@@ -440,6 +460,7 @@ class Device {
440460
}
441461
}
442462
} else {
463+
if (prof_->SkipExec()) return;
443464
// reset
444465
for (uint32_t y = 0; y < op->iter_out; ++y) {
445466
for (uint32_t x = 0; x < op->iter_in; ++x) {
@@ -506,6 +527,7 @@ class Device {
506527
template<bool use_imm, typename F>
507528
void RunALULoop(const VTAAluInsn* op, F func) {
508529
prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn;
530+
if (prof_->SkipExec()) return;
509531
for (int y = 0; y < op->iter_out; ++y) {
510532
for (int x = 0; x < op->iter_in; ++x) {
511533
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
@@ -548,6 +570,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
548570
.set_body([](TVMArgs args, TVMRetValue* rv) {
549571
Profiler::ThreadLocal()->Clear();
550572
});
573+
TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode")
574+
.set_body([](TVMArgs args, TVMRetValue* rv) {
575+
Profiler::ThreadLocal()->debug_flag = args[0];
576+
});
551577
TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
552578
.set_body([](TVMArgs args, TVMRetValue* rv) {
553579
*rv = Profiler::ThreadLocal()->AsJSON();

vta/tests/python/unittest/test_vta_insn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,16 @@ def verify(s):
183183

184184
if env.TARGET == "sim":
185185
simulator.clear_stats()
186+
simulator.debug_mode(simulator.DEBUG_SKIP_EXEC)
186187
f(x_nd, w_nd, y_nd)
187-
print(simulator.stats())
188+
stat1 = simulator.stats()
189+
simulator.clear_stats()
190+
simulator.debug_mode(0)
191+
f(x_nd, w_nd, y_nd)
192+
stat2 = simulator.stats()
193+
for k, v in stat1.items():
194+
if k != "uop_load_nbytes":
195+
assert stat1[k] == stat2[k]
188196
else:
189197
f(x_nd, w_nd, y_nd)
190198

0 commit comments

Comments
 (0)