Skip to content

Commit 2b51e79

Browse files
ZihengJiangtmoreau89
authored andcommitted
Hack for inserting start/end point of quantizing. (apache#33)
* Hack for inserting start/end point of quantizing. * Fix.
1 parent f301b7f commit 2b51e79

5 files changed

Lines changed: 61 additions & 4 deletions

File tree

include/tvm/relay/attrs/transform.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,18 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
150150
}
151151
};
152152

153+
154+
/*! \brief Attributes for Annotate operator */
155+
struct AnnotateAttrs : public tvm::AttrsNode<AnnotateAttrs> {
156+
std::string info;
157+
158+
TVM_DECLARE_ATTRS(AnnotateAttrs, "relay.attrs.AnnotateAttrs") {
159+
TVM_ATTR_FIELD(info)
160+
.describe("The annotation info.");
161+
}
162+
};
163+
164+
153165
/*! \brief Attributes for Clip operator */
154166
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
155167
double a_min;

python/tvm/relay/quantize/quantize.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .. import op as _op
99
from .. import build_module as _build
1010
from ..base import register_relay_node
11-
from ..._ffi.function import register_func
11+
from ..._ffi.function import register_func, get_global_func
1212

1313

1414
class QFieldKind(object):
@@ -72,14 +72,27 @@ def simulated_quantize(data, dom_scale, bit, clip_min, clip_max,
7272
sign, rounding, kind)
7373

7474

75+
76+
SQ_CACHE_MAP = {}
77+
78+
7579
@register_func("relay.quantize.attach_simulated_quantize")
7680
def attach_simulated_quantize(data, kind):
81+
global SQ_CACHE_MAP
82+
key = data
83+
if data in SQ_CACHE_MAP:
84+
return SQ_CACHE_MAP[data]
85+
if len(SQ_CACHE_MAP) == 0:
86+
f = get_global_func("relay._quantize.make_annotate_op")
87+
data = f(data, "quantize_start")
7788
dom_scale = _expr.var("dom_scale")
7889
bit = _expr.var("bit")
7990
clip_min = _expr.var("clip_min")
8091
clip_max = _expr.var("clip_max")
81-
return simulated_quantize(data, dom_scale, bit, clip_min, clip_max,
82-
True, "round", kind)
92+
ret = simulated_quantize(data, dom_scale, bit, clip_min, clip_max,
93+
True, "round", kind)
94+
SQ_CACHE_MAP[key] = ret
95+
return ret
8396

8497

8598
def register_qfield_rewrite(op_name, frewrite=None, level=10):

src/relay/op/tensor/unary.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ RELAY_REGISTER_UNARY_OP("copy")
8181
.set_support_level(3)
8282
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
8383

84+
85+
// relay.annotate
86+
TVM_REGISTER_NODE_TYPE(AnnotateAttrs);
87+
88+
RELAY_REGISTER_OP("annotate")
89+
.describe(R"code(Copy a tensor with annotation information
90+
)code" TVM_ADD_FILELINE)
91+
.set_num_inputs(1)
92+
.add_argument("data", "Tensor", "The input tensor.")
93+
.add_type_rel("Identity", IdentityRel)
94+
.set_attr<TOpPattern>("TOpPattern", kElemWise)
95+
.set_attr<TOpIsStateful>("TOpIsStateful", false)
96+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
97+
.set_attrs_type_key("relay.attrs.AnnotateAttrs")
98+
.set_support_level(10)
99+
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));
100+
101+
84102
// relay.clip
85103
TVM_REGISTER_NODE_TYPE(ClipAttrs);
86104

src/relay/pass/quantize.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,26 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize")
6666
});
6767

6868

69+
Expr MakeAnnotateOp(Expr data, std::string info) {
70+
auto attrs = make_node<AnnotateAttrs>();
71+
attrs->info = info;
72+
static const Op& op = Op::Get("annotate");
73+
return CallNode::make(op, {data}, Attrs(attrs), {});
74+
}
75+
76+
TVM_REGISTER_API("relay._quantize.make_annotate_op")
77+
.set_body_typed<Expr(Expr, std::string)>(MakeAnnotateOp);
78+
79+
6980
// =============
7081
// annotate pass
7182

7283
Expr QFieldExprNode::Realize() const {
7384
// store low bit output back
7485
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
75-
return (*f)(expr, static_cast<int>(kQInput));
86+
Expr ret = (*f)(expr, static_cast<int>(kQInput));
87+
ret = MakeAnnotateOp(ret, "quantize_end");
88+
return ret;
7689
}
7790

7891
QFieldExpr QFieldExprNode::make(Expr expr, QFieldKind kind) {

src/relay/pass/quantize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
3636
}
3737
};
3838

39+
3940
class QFieldExpr;
4041

4142
class QFieldExprNode : public TempExprNode {

0 commit comments

Comments
 (0)