|
8 | 8 | from .. import op as _op |
9 | 9 | from .. import build_module as _build |
10 | 10 | from ..base import register_relay_node |
11 | | -from ..._ffi.function import register_func |
| 11 | +from ..._ffi.function import register_func, get_global_func |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class QFieldKind(object): |
@@ -72,14 +72,27 @@ def simulated_quantize(data, dom_scale, bit, clip_min, clip_max, |
72 | 72 | sign, rounding, kind) |
73 | 73 |
|
74 | 74 |
|
| 75 | + |
| 76 | +SQ_CACHE_MAP = {} |
| 77 | + |
| 78 | + |
75 | 79 | @register_func("relay.quantize.attach_simulated_quantize") |
76 | 80 | def attach_simulated_quantize(data, kind): |
| 81 | + global SQ_CACHE_MAP |
| 82 | + key = data |
| 83 | + if data in SQ_CACHE_MAP: |
| 84 | + return SQ_CACHE_MAP[data] |
| 85 | + if len(SQ_CACHE_MAP) == 0: |
| 86 | + f = get_global_func("relay._quantize.make_annotate_op") |
| 87 | + data = f(data, "quantize_start") |
77 | 88 | dom_scale = _expr.var("dom_scale") |
78 | 89 | bit = _expr.var("bit") |
79 | 90 | clip_min = _expr.var("clip_min") |
80 | 91 | clip_max = _expr.var("clip_max") |
81 | | - return simulated_quantize(data, dom_scale, bit, clip_min, clip_max, |
82 | | - True, "round", kind) |
| 92 | + ret = simulated_quantize(data, dom_scale, bit, clip_min, clip_max, |
| 93 | + True, "round", kind) |
| 94 | + SQ_CACHE_MAP[key] = ret |
| 95 | + return ret |
83 | 96 |
|
84 | 97 |
|
85 | 98 | def register_qfield_rewrite(op_name, frewrite=None, level=10): |
|
0 commit comments