Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,15 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
};

/*!
 * \brief Check whether the interval [min_value, max_value] covers the whole value range of dtype.
 *
 * \param dtype The data type whose minimum and maximum values are queried.
 * \param min_value Lower end of the interval to test (e.g. the a_min of a clip).
 * \param max_value Upper end of the interval to test (e.g. the a_max of a clip).
 * \return true when min_value <= min(dtype) and max_value >= max(dtype). false otherwise,
 *         including for any dtype that is not int, uint, float or bfloat16.
 */
bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) {
  double lbound{}, ubound{};
  if (dtype.is_int() || dtype.is_uint()) {
    // NOTE(review): assumes tvm::max_value/min_value yield an IntImm for every int/uint width
    // that reaches this point -- confirm for 64-bit unsigned.
    ubound = static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
    lbound = static_cast<double>(Downcast<IntImm>(tvm::min_value(dtype))->value);
  } else if (dtype.is_float() || dtype.is_bfloat16()) {
    ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
    lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
  } else {
    // The bounds of any other dtype are unknown here. Falling through with the zero-initialised
    // bounds would report "covered" for every interval containing 0, so answer conservatively.
    return false;
  }

  return max_value >= ubound && min_value <= lbound;
}

/*!
Expand Down Expand Up @@ -224,8 +222,8 @@ class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
};

/*!
* \brief SimplifyCastClip matches the pattern cast->clip and remove redundant Cast based on Clip
* min/max values and min/max values of Cast target data type.
* \brief SimplifyClip removes redundant Clip based on its a_min/a_max values and the min/max values
* of the data type.
*
* Example:
* %1 = cast(%0, dtype="uint8") [type=uint8]
Expand All @@ -234,30 +232,38 @@ class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
* Optimized to (remove Clip):
* %1 = cast(%0, dtype="uint8") [type=uint8]
*/
class SimplifyCastClip : public DFPatternRewrite {
class SimplifyClip : public DFPatternRewrite {
public:
SimplifyCastClip() {
cast_ = IsOp("cast")({IsWildcard()});
pattern_ = IsOp("clip")({cast_});
SimplifyClip() {
x_ = IsWildcard();
pattern_ = IsOp("clip")({x_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto cast = Downcast<Call>(node_map[cast_][0]);
DataType cast_dtype = Downcast<TensorType>(cast->checked_type())->dtype;
DataType cast_dtype = Downcast<TensorType>(pre->checked_type())->dtype;

auto clip = Downcast<Call>(post);
const CallNode* clip_node = clip.as<CallNode>();
const CallNode* clip_node = post.as<CallNode>();
const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();

// TODO(kfeng123): For now, the arg of "clip" is forced to not be "qnn.requantize" and
// "qnn.add". This is to avoid destroying the structure required by LegalizeQnnOpForDnnl
auto child{post.as<CallNode>()->args[0].as<CallNode>()};
if (child && child->op.as<OpNode>()) {
String op_name{child->op.as<OpNode>()->name};
if (op_name == "qnn.requantize" || op_name == "qnn.add") {
return post;
}
}

if (CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
return node_map[cast_][0];
return node_map[x_][0];
}
return post;
}

protected:
DFPattern clip_, cast_;
DFPattern x_;
};

/*!
Expand Down Expand Up @@ -992,7 +998,7 @@ class SimplifyBinomial : public DFPatternRewrite {
DFPattern y_;
};

/*! \brief Simplifying x/sqrt to x*sqrt */
/*! \brief Simplifying x/sqrt to x*rsqrt */
class SimplifyRSqrt : public DFPatternRewrite {
public:
SimplifyRSqrt() {
Expand Down Expand Up @@ -1085,7 +1091,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<SimplifyDQArgMin>();
composer.AddRewrite<SimplifyDQArgSort>();
composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
composer.AddRewrite<SimplifyCastClip>();
composer.AddRewrite<SimplifyClip>();
composer.AddRewrite<SimplifyBinomial>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}
Expand All @@ -1099,7 +1105,7 @@ Expr SimplifyExprPostAlterOp(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
composer.AddRewrite<SimplifyCastClip>();
composer.AddRewrite<SimplifyClip>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}

Expand Down
82 changes: 66 additions & 16 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,34 +722,84 @@ def expected():


def test_simplify_clip_cast():
    """clip(0, 255) -> cast(uint8) -> cast(int32) on an int32 input keeps only the clip."""

    def before1():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        clip = relay.clip(x, a_min=0.0, a_max=255.0)
        cast = relay.cast(clip, "uint8")
        cast = relay.cast(cast, "int32")
        return relay.Function([x], cast)

    def expected1():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        clip = relay.clip(x, a_min=0.0, a_max=255.0)
        return relay.Function([x], clip)

    # NOTE(review): before2/expected2 build exactly the same graphs as before1/expected1;
    # confirm whether a different scenario was intended here.
    def before2():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        clip = relay.clip(x, a_min=0.0, a_max=255.0)
        cast = relay.cast(clip, "uint8")
        cast = relay.cast(cast, "int32")
        return relay.Function([x], cast)

    def expected2():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        clip = relay.clip(x, a_min=0.0, a_max=255.0)
        return relay.Function([x], clip)

    cases = [
        [before1(), expected1()],
        [before2(), expected2()],
    ]
    # Distinct loop names so the helper functions above are not shadowed.
    for before_func, expected_func in cases:
        after = run_opt_pass(before_func, transform.SimplifyExpr())
        reference = run_opt_pass(expected_func, transform.InferType())
        assert tvm.ir.structural_equal(after, reference), "\nafter: {} \nexpected: {}".format(
            after, reference
        )


def test_simplify_cast_clip():
    """A clip is removed only when its range covers the full range of its data type."""

    # Case 1: clip(0, 255) after a cast to uint8 is redundant and is removed.
    def before1():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        cast = relay.cast(x, "uint8")
        clip = relay.clip(cast, a_min=0.0, a_max=255.0)
        return relay.Function([x], clip)

    def expected1():
        x = relay.var("x", shape=(4, 8), dtype="int32")
        cast = relay.cast(x, "uint8")
        return relay.Function([x], cast)

    # Case 2: clip(0, 255) directly on a uint8 input is removed, with no cast involved.
    def before2():
        x = relay.var("x", shape=(4, 8), dtype="uint8")
        clip = relay.clip(x, a_min=0.0, a_max=255.0)
        return relay.Function([x], clip)

    def expected2():
        x = relay.var("x", shape=(4, 8), dtype="uint8")
        return relay.Function([x], x)

    # Case 3: clip(-0.2, 0.4) on bfloat16 is expected to be kept unchanged.
    def before3():
        x = relay.var("x", shape=(4, 8), dtype="float32")
        cast = relay.cast(x, "bfloat16")
        clip = relay.clip(cast, a_min=-0.2, a_max=0.4)
        return relay.Function([x], clip)

    def expected3():
        x = relay.var("x", shape=(4, 8), dtype="float32")
        cast = relay.cast(x, "bfloat16")
        clip = relay.clip(cast, a_min=-0.2, a_max=0.4)
        return relay.Function([x], clip)

    cases = [
        [before1(), expected1()],
        [before2(), expected2()],
        [before3(), expected3()],
    ]
    # Distinct loop names so the helper functions above are not shadowed.
    for before_func, expected_func in cases:
        after = run_opt_pass(before_func, transform.SimplifyExpr())
        reference = run_opt_pass(expected_func, transform.InferType())
        assert tvm.ir.structural_equal(after, reference), "\nafter: {} \nexpected: {}".format(
            after, reference
        )


def test_simplify_add():
Expand Down