From dcbf83e36ac59fb0bd045c604a9657036ff4b9a4 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 2 Aug 2023 13:42:38 +0300 Subject: [PATCH 1/2] [Relay] Stop ToMixedPrecision when constant is out of dtype range In some layers, e.g. Clip, we might have a compilation error in the case when operation takes on the input a constant which is out of target data type range. To prevent such situation, a new method was introduced. It compares values of constant attributes with the range of the target data type. In case if the value is out of range then float32 will be used. --- src/relay/transforms/to_mixed_precision.cc | 41 +++++++++++++++- tests/python/relay/test_to_mixed_precision.py | 49 +++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 820bc6e58e4d..c54e4704b3a5 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -31,6 +31,7 @@ #include +#include "../../support/scalars.h" #include "pattern_utils.h" namespace tvm { @@ -110,6 +111,39 @@ class MixedPrecisionPass : public MixedModeMutator { std::vector original_dtype_; bool keep_orig_output_dtype_; + /*! \brief If some of the constant attributes are out of mixed_precision_type_ bounds, then + * computation cannot be performed in mixed precision. */ + bool IsMixedPrecisionApplicableToAttrs(const Attrs& attrs) const { + if (attrs.get() != nullptr) { + double min_bound; + double max_bound; + if (mixed_precision_type_.is_float16()) { + min_bound = -support::kMaxFloat16; + max_bound = support::kMaxFloat16; + } else if (mixed_precision_type_.is_bfloat16()) { + min_bound = -support::kMaxBFloat16; + max_bound = support::kMaxBFloat16; + } else if (mixed_precision_type_.is_float8()) { + double bound = (mixed_precision_type_.code() == DataType::kE4M3Float) ? support::kMaxE4M3 + : support::kMaxE5M2; + min_bound = -bound; + max_bound = bound; + } else if (mixed_precision_type_.is_float()) { + min_bound = std::numeric_limits::lowest(); + max_bound = std::numeric_limits::max(); + } else { + return true; + } + + if (auto cur_attrs = attrs.as()) { + if (cur_attrs->a_min < min_bound || cur_attrs->a_max > max_bound) { + return false; + } + } + } + return true; + } + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ Attrs cur_attrs = call->attrs; @@ -382,9 +416,12 @@ class MixedPrecisionPass : public MixedModeMutator { all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER; } + bool is_mixed_precision_applicable = + (bool)(final_category == MIXED_PRECISION_ALWAYS && + IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs)); // Create the new arguments to the call. DataType wanted_arg_dtypes = - final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32); + is_mixed_precision_applicable ? mixed_precision_type_ : DataType::Float(32); auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes); Array new_args = call_args_and_types.first; Array new_arg_types; @@ -397,7 +434,7 @@ class MixedPrecisionPass : public MixedModeMutator { } // Finally create the new attributes. - if (final_category == MIXED_PRECISION_ALWAYS) { + if (is_mixed_precision_applicable) { Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype); Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span); if (accumulation_dtype != output_dtype) { diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 771d366df079..a802eee6d644 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -537,5 +537,54 @@ def test_convert_follow_node_with_integer_arguments(target_precision): assert tvm.ir.structural_equal(expected_mod, output_mod) +def test_clip(target_precision): + data = relay.var("data", shape=[1, 10], dtype="float32") + res = relay.clip(data, a_min=-128000, a_max=128000) + + mod = tvm.IRModule.from_expr(res) + + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) + + # Create expected module + if target_precision == "bfloat16": + data = relay.cast(relay.var("data", shape=[1, 10]), target_precision) + res = relay.clip(data, a_min=-128000, a_max=128000) + expected_mod = tvm.IRModule.from_expr(res) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_clip_with_pre_op(target_precision): + data = relay.var("data", shape=[1, 10], dtype="float32") + const = relay.const(5, "float32") + res = relay.divide(data, const) + res = relay.clip(res, a_min=-128000, a_max=128000) + + mod = tvm.IRModule.from_expr(res) + + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) + + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 10]), target_precision) + const = relay.cast(relay.const(5, "float32"), target_precision) + res = relay.divide(data, const) + if target_precision == "float16": + res = relay.cast(res, "float32") + res = relay.clip(res, a_min=-128000, a_max=128000) + expected_mod = tvm.IRModule.from_expr(res) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + if __name__ == "__main__": tvm.testing.main() From bfa63d60c2a470db35c0a6e759b50ea2658c629d Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 2 Aug 2023 21:21:50 +0300 Subject: [PATCH 2/2] Fix lint --- src/relay/transforms/to_mixed_precision.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index c54e4704b3a5..4638ee547706 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -417,8 +417,8 @@ class MixedPrecisionPass : public MixedModeMutator { } bool is_mixed_precision_applicable = - (bool)(final_category == MIXED_PRECISION_ALWAYS && - IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs)); + static_cast(final_category == MIXED_PRECISION_ALWAYS && + IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs)); // Create the new arguments to the call. DataType wanted_arg_dtypes = is_mixed_precision_applicable ? mixed_precision_type_ : DataType::Float(32);