Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Disable fq2i for nn.softmax by default
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
  • Loading branch information
maekawatoshiki and Toshiki Maekawa committed May 14, 2023
commit 0f502449e585b163b0bfbe5f231e7f08ff1bed49
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
debug,
register_external_compiler,
register_fake_quantization_to_integer,
register_optional_fake_quantization_to_integer,
register_mixed_precision_conversion,
)
from . import strategy
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,27 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)


def register_optional_fake_quantization_to_integer(op_name, func=None, level=10):
    """Register an optional quantize function for an op.

    Given an op and Affine Types on its inputs, this function should return the op
    in affine space/integer operators and the new type of the output, where affine
    denotes the transformation x_real = (x_affine - zero_point) * scale

    Unlike ``register_fake_quantization_to_integer``, ops registered through this
    function are NOT converted by the FakeQuantizationToInteger pass by default:
    the rewrite is stored under the "FTVMOptionalFakeQuantizationToInteger" attr
    and is only applied when the op's name appears in the pass's
    ``optional_qnn_ops`` argument (e.g. ``optional_qnn_ops=["nn.softmax"]``).

    Parameters
    ----------
    op_name : str
        The name of the operator

    func: function (expr: Expr, map: Map<Expr, AffineType>) -> new_expr: Expr
        The function for translating the op into affine space and integer operators

    level : int
        The priority level
    """
    return tvm.ir.register_op_attr(op_name, "FTVMOptionalFakeQuantizationToInteger", func, level)


def register_mixed_precision_conversion(op_name, func=None, level=10):
"""Register mixed precision conversion function for an op

Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from tvm.relay.qnn.op import canonicalizations
from tvm.tir import bijective_layout

from ..op import register_fake_quantization_to_integer
from ..op import (
register_fake_quantization_to_integer,
register_optional_fake_quantization_to_integer,
)


def fold_constant(expr):
Expand Down Expand Up @@ -635,7 +638,7 @@ def take(expr, type_map):
return [out, t]


@register_fake_quantization_to_integer("nn.softmax")
@register_optional_fake_quantization_to_integer("nn.softmax")
def softmax(expr, type_map):
"""Rewrite a softmax op"""
arg = expr.args[0]
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ def AnnotateSpans():
return _ffi_api.AnnotateSpans()


def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
def FakeQuantizationToInteger(hard_fail=False, use_qat=False, optional_qnn_ops=[]):
# pylint: disable=anomalous-backslash-in-string
"""
Find regions of the graph of the form
Expand Down Expand Up @@ -1298,12 +1298,17 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
|
q

optional_qnn_ops : List[str]
Specify a list of operator names to explicitly enable conversion for
specific ops disabled by default.
Example: ['nn.softmax']

Returns
-------
ret : tvm.transform.Pass
The registered FakeQuantizationToInteger pass.
"""
return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)
return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat, optional_qnn_ops)


def FlattenAtrousConv():
Expand Down
73 changes: 53 additions & 20 deletions src/relay/transforms/fake_quantization_to_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {

class SubgraphMutator : public ExprMutator {
public:
SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
: subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail,
const std::unordered_set<String>& optional_qnn_ops)
: subgraph_(subgraph),
affine_types_(affine_types),
hard_fail_(hard_fail),
optional_qnn_ops_(optional_qnn_ops) {}

Expr MutateSubgraph(const Expr& expr) {
if (subgraph_.size() == 0) {
Expand All @@ -176,9 +180,14 @@ class SubgraphMutator : public ExprMutator {
out_type_ = affine_types_[expr];
static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
static auto opt_fqfq =
Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
: fqfq;
for (auto node : subgraph_) {
const Op op = Downcast<Op>(node.as<CallNode>()->op);
if (!fqfq.count(Downcast<Op>(op))) {
if (!fqfq.count(Downcast<Op>(op)) &&
!(optional_qnn_ops_.count(op->name) && opt_fqfq.count(Downcast<Op>(op)))) {
// Only modify the subgraph if we have translation
// rules for every op
if (hard_fail_) {
Expand Down Expand Up @@ -207,8 +216,12 @@ class SubgraphMutator : public ExprMutator {

static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
static auto opt_fqfq =
Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
: fqfq;
Op op = Downcast<Op>(call_node->op);
if (fqfq.count(op)) {
if (fqfq.count(op) || (optional_qnn_ops_.count(op->name) && opt_fqfq.count(op))) {
Expr expr;
if (op == dequantize_op_) {
expr = GetRef<Expr>(call_node);
Expand All @@ -219,7 +232,7 @@ class SubgraphMutator : public ExprMutator {
affine_types_.Set(expr, out_type_);
}
// Call the rewrite
Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
Array<ObjectRef> vals = (fqfq.count(op) ? fqfq : opt_fqfq)[op](expr, affine_types_);
// Save the outputs of the rewrite
ICHECK(vals.size() == 2)
<< "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
Expand Down Expand Up @@ -256,13 +269,16 @@ class SubgraphMutator : public ExprMutator {
AffineTypeMap affine_types_;
AffineType out_type_;
const bool hard_fail_;
const std::unordered_set<String>& optional_qnn_ops_;
const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
};

class FakeQuantizationRewriter : public MixedModeMutator {
public:
explicit FakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
explicit FakeQuantizationRewriter(bool hard_fail,
const std::unordered_set<String>& optional_qnn_ops)
: hard_fail_(hard_fail), optional_qnn_ops_(optional_qnn_ops) {}

protected:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expand All @@ -286,15 +302,16 @@ class FakeQuantizationRewriter : public MixedModeMutator {
for (auto expr : subgraph) {
post_subgraph.insert(memo_[expr]);
}
Expr out =
SubgraphMutator(post_subgraph, post_affine_types, hard_fail_).MutateSubgraph(post);
Expr out = SubgraphMutator(post_subgraph, post_affine_types, hard_fail_, optional_qnn_ops_)
.MutateSubgraph(post);
return out;
}
}
return post;
}
const Op quantize_op_ = Op::Get("qnn.quantize");
const bool hard_fail_;
const std::unordered_set<String>& optional_qnn_ops_;
};

/* Checks if the operation to convert QAT pass is enabled.
Expand Down Expand Up @@ -404,8 +421,12 @@ class QATSubgraphExtractor : public ExprVisitor {

class QATSubgraphMutator : public ExprMutator {
public:
QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
: subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail,
const std::unordered_set<String>& optional_qnn_ops)
: subgraph_(subgraph),
affine_types_(affine_types),
hard_fail_(hard_fail),
optional_qnn_ops_(optional_qnn_ops) {}

Expr MutateSubgraph(const Expr& expr) {
if (subgraph_.size() == 0) {
Expand Down Expand Up @@ -447,17 +468,21 @@ class QATSubgraphMutator : public ExprMutator {
Expr out;
static auto fqfq =
Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
static auto opt_fqfq =
Op::HasAttrMap("FTVMOptionalFakeQuantizationToInteger")
? Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMOptionalFakeQuantizationToInteger")
: fqfq;

Op op = Downcast<Op>(call_node->op);
if (fqfq.count(op)) {
if (fqfq.count(op) || (optional_qnn_ops_.count(op->name) && opt_fqfq.count(op))) {
Expr expr;
if (op == dequantize_op_) {
expr = GetRef<Expr>(call_node);
} else {
expr = ExprMutator::VisitExpr_(call_node);
}
// Call the rewrite
Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
Array<ObjectRef> vals = (fqfq.count(op) ? fqfq : opt_fqfq)[op](expr, affine_types_);
// Save the outputs of the rewrite
ICHECK(vals.size() == 2)
<< "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
Expand Down Expand Up @@ -500,13 +525,15 @@ class QATSubgraphMutator : public ExprMutator {
ExprSet subgraph_;
AffineTypeMap affine_types_;
const bool hard_fail_;
const std::unordered_set<String>& optional_qnn_ops_;
const Op dequantize_op_ = Op::Get("qnn.dequantize");
const CallNode* quantize_node_ = nullptr;
};

class QATRewriter : public MixedModeMutator {
public:
explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
explicit QATRewriter(bool hard_fail, const std::unordered_set<String>& optional_qnn_ops)
: hard_fail_(hard_fail), optional_qnn_ops_(optional_qnn_ops) {}

protected:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Expand All @@ -516,31 +543,37 @@ class QATRewriter : public MixedModeMutator {
QATSubgraphExtractor extractor;
ExprSet subgraph = extractor.GetSubgraph(post);
AffineTypeMap affine_types = extractor.GetAffineTypes();
Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_, optional_qnn_ops_)
.MutateSubgraph(post);
return out;
}
}
return post;
}
const bool hard_fail_;
const std::unordered_set<String>& optional_qnn_ops_;
};

// Rewrite fake-quantized regions of `expr` into integer (QNN) operators.
//
// `hard_fail`        - if true, abort when a subgraph contains an op with no
//                      registered rewrite; otherwise leave that subgraph as-is.
// `use_qat`          - additionally run the QAT rewriter (requires type
//                      inference between the two passes).
// `optional_qnn_ops` - names of ops whose fq2i conversion is disabled by
//                      default (registered via
//                      FTVMOptionalFakeQuantizationToInteger) but should be
//                      enabled for this invocation, e.g. {"nn.softmax"}.
Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail, bool use_qat,
                               const Array<String>& optional_qnn_ops) {
  // Copy into a set for O(1) name lookups during the rewrite.
  const std::unordered_set<String> optional_qnn_ops_(optional_qnn_ops.begin(),
                                                     optional_qnn_ops.end());
  auto fq_expr = FakeQuantizationRewriter(hard_fail, optional_qnn_ops_).Mutate(expr);
  if (use_qat) {
    // The QAT rewriter relies on checked types, so re-infer them first.
    fq_expr = tvm::relay::InferType(fq_expr);
    fq_expr = QATRewriter(hard_fail, optional_qnn_ops_).Mutate(fq_expr);
  }
  return fq_expr;
}

namespace transform {

Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) {
Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat,
const Array<String>& optional_qnn_ops) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat));
return Downcast<Function>(
FakeQuantizationToInteger(f, m, hard_fail, use_qat, optional_qnn_ops));
};
return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType", "DivToMul"});
}
Expand Down
4 changes: 3 additions & 1 deletion tests/python/relay/test_pass_fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,9 @@ def test_fake_quantize_softmax():

mod = tvm.IRModule.from_expr(op)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger(
hard_fail=True, optional_qnn_ops=["nn.softmax"]
)(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

result = (
Expand Down