Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 57 additions & 29 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
DFPattern cast1_;
};

bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) {
/*! \brief If mode == 0, return true if the interval [min_value, max_value] contains the range of
 * dtype, and false otherwise. If mode == 1, return true if the interval [min_value, max_value] is
 * contained by the range of dtype, and false otherwise. */
bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value, int mode = 0) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please comment the meaning of different mode

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

double lbound{}, ubound{};
if (dtype.is_int() || dtype.is_uint()) {
ubound = static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
Expand All @@ -169,56 +172,81 @@ bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value
ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
}
return max_value >= ubound && min_value <= lbound;
if (mode == 0) {
return max_value >= ubound && min_value <= lbound;
} else if (mode == 1) {
return max_value <= ubound && min_value >= lbound;
} else {
LOG(FATAL) << "invalid mode " << mode << " in CheckDataTypeMaxMinValue";
return false;
}
}

/*!
* \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast and remove redundant
* casts.
* Analysis of "redundancy" is done based on clip min/max values and min/max values of casted data
* type.
 * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->...->cast and removes
 * redundant casts. Analysis of "redundancy" is done based on the clip min/max values and the
 * min/max values of the casted data types.
*
* Example:
* %0 == [type=int32]
* %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
* %2 = cast(%1, dtype="uint8") [type=uint8]
* %3 = cast(%2, dtype="int32") [type=int32]
*
* Optimized to (both casts can be removed):
* %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
*/
class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
public:
SimplifyClipAndConsecutiveCast() {
clip_ = IsOp("clip")({IsWildcard()});
cast1_ = IsOp("cast")({clip_});
pattern_ = IsOp("cast")({cast1_});
ObjectPtr<CallPatternNode> pattern_ptr = make_object<CallPatternNode>();
pattern_ptr->op = IsOp("cast");
pattern_ptr->args.clear();
pattern_ = CallPattern(pattern_ptr);
AltPattern or_pattern{pattern_, clip_};
pattern_ptr->args.push_back(or_pattern);
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto clip = Downcast<Call>(node_map[clip_][0]);
const CallNode* clip_node = clip.as<CallNode>();
const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;

auto cast1 = Downcast<Call>(node_map[cast1_][0]);
DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;
std::vector<Expr> remaining_casts{};
Expr cast_expr{post};
while (cast_expr != clip) {
DataType cast_dtype = Downcast<TensorType>(cast_expr->checked_type())->dtype;
if (!CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max, 1)) {
remaining_casts.push_back(cast_expr);
}
cast_expr = cast_expr.as<CallNode>()->args[0];
}

auto cast2 = Downcast<Call>(post);
DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
Expr last_op = (remaining_casts.size() == 0) ? clip : remaining_casts[0];
DataType last_op_dtype = Downcast<TensorType>(last_op->checked_type())->dtype;
bool need_additional_cast{false};
if (last_op_dtype != Downcast<TensorType>(post->checked_type())->dtype) {
need_additional_cast = true;
}

if (clip_dtype == cast2_dtype &&
CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
// Case 1:
// Data type of Clip == target data type of the second Cast, and the Clip min/max values fit
// within the range of the first Cast's target data type. In this case both Cast ops can be removed.
// Example:
// %0 == [type=int32]
// %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
// %2 = cast(%1, dtype="uint8") [type=uint8]
// %3 = cast(%2, dtype="int32") [type=int32]
//
// Optimized to (both casts can be removed):
// %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
return node_map[clip_][0];
Expr res{clip};
for (size_t i = remaining_casts.size(); i > 0; --i) {
auto attrs = make_object<CastAttrs>();
attrs->dtype = remaining_casts[i - 1].as<CallNode>()->attrs.as<CastAttrs>()->dtype;
res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
}
return post;
if (need_additional_cast) {
auto attrs = make_object<CastAttrs>();
attrs->dtype = Downcast<TensorType>(post->checked_type())->dtype;
res = Call(Op::Get("cast"), {res}, Attrs(attrs), {});
}
return res;
}

protected:
DFPattern clip_, cast1_;
DFPattern clip_;
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/aot/test_crt_aot_usmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_byoc_microtvm(merge_compiler_regions):
"model_url, usmp_algo, workspace_size, constant_size",
[
(MOBILENET_V1_URL, "greedy_by_size", 4845696, 8468008),
(MOBILENET_V1_URL, "greedy_by_conflicts", 4845696, 8468008),
(MOBILENET_V1_URL, "greedy_by_conflicts", 4444288, 8468008),
(MOBILENET_V1_URL, "hill_climb", 3240064, 8468008),
],
)
Expand Down
45 changes: 45 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,54 @@ def expected2():
clip = relay.clip(x, a_min=0.0, a_max=255.0)
return relay.Function([x], clip)

def before3():
x = relay.var("x", shape=(4, 8), dtype="int32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "uint8")
cast = relay.cast(cast, "int16")
cast = relay.cast(cast, "int32")
return relay.Function([x], cast)

def expected3():
x = relay.var("x", shape=(4, 8), dtype="int32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
return relay.Function([x], clip)

def before4():
x = relay.var("x", shape=(4, 8), dtype="float32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "uint8")
cast = relay.cast(cast, "int16")
cast = relay.cast(cast, "int32")
return relay.Function([x], cast)

def expected4():
x = relay.var("x", shape=(4, 8), dtype="float32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "int32")
return relay.Function([x], cast)

def before5():
x = relay.var("x", shape=(4, 8), dtype="float32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "int8")
cast = relay.cast(cast, "int16")
cast = relay.cast(cast, "int32")
return relay.Function([x], cast)

def expected5():
x = relay.var("x", shape=(4, 8), dtype="float32")
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "int8")
cast = relay.cast(cast, "int32")
return relay.Function([x], cast)

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
[before4(), expected4()],
[before5(), expected5()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
Expand Down