diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 820bc6e58e4d..4638ee547706 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -31,6 +31,7 @@ #include +#include "../../support/scalars.h" #include "pattern_utils.h" namespace tvm { @@ -110,6 +111,39 @@ class MixedPrecisionPass : public MixedModeMutator { std::vector original_dtype_; bool keep_orig_output_dtype_; + /*! \brief If some of the constant attributes are out of mixed_precision_type_ bounds, then + * computation cannot be performed in mixed precision. */ + bool IsMixedPrecisionApplicableToAttrs(const Attrs& attrs) const { + if (attrs.get() != nullptr) { + double min_bound; + double max_bound; + if (mixed_precision_type_.is_float16()) { + min_bound = -support::kMaxFloat16; + max_bound = support::kMaxFloat16; + } else if (mixed_precision_type_.is_bfloat16()) { + min_bound = -support::kMaxBFloat16; + max_bound = support::kMaxBFloat16; + } else if (mixed_precision_type_.is_float8()) { + double bound = (mixed_precision_type_.code() == DataType::kE4M3Float) ? support::kMaxE4M3 + : support::kMaxE5M2; + min_bound = -bound; + max_bound = bound; + } else if (mixed_precision_type_.is_float()) { + min_bound = std::numeric_limits::lowest(); + max_bound = std::numeric_limits::max(); + } else { + return true; + } + + if (auto cur_attrs = attrs.as()) { + if (cur_attrs->a_min < min_bound || cur_attrs->a_max > max_bound) { + return false; + } + } + } + return true; + } + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { /* If the accumulation dtype is in the attributes make a copy and mutate the field. */ Attrs cur_attrs = call->attrs; @@ -382,9 +416,12 @@ class MixedPrecisionPass : public MixedModeMutator { all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER; } + bool is_mixed_precision_applicable = + static_cast(final_category == MIXED_PRECISION_ALWAYS && + IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs)); // Create the new arguments to the call. DataType wanted_arg_dtypes = - final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32); + is_mixed_precision_applicable ? mixed_precision_type_ : DataType::Float(32); auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes); Array new_args = call_args_and_types.first; Array new_arg_types; @@ -397,7 +434,7 @@ class MixedPrecisionPass : public MixedModeMutator { } // Finally create the new attributes. - if (final_category == MIXED_PRECISION_ALWAYS) { + if (is_mixed_precision_applicable) { Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype); Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span); if (accumulation_dtype != output_dtype) { diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 771d366df079..a802eee6d644 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -537,5 +537,54 @@ def test_convert_follow_node_with_integer_arguments(target_precision): assert tvm.ir.structural_equal(expected_mod, output_mod) +def test_clip(target_precision): + data = relay.var("data", shape=[1, 10], dtype="float32") + res = relay.clip(data, a_min=-128000, a_max=128000) + + mod = tvm.IRModule.from_expr(res) + + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) + + # Create expected module + if target_precision == "bfloat16": + data = relay.cast(relay.var("data", shape=[1, 10]), target_precision) + res = relay.clip(data, a_min=-128000, a_max=128000) + expected_mod = tvm.IRModule.from_expr(res) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + +def test_clip_with_pre_op(target_precision): + data = relay.var("data", shape=[1, 10], dtype="float32") + const = relay.const(5, "float32") + res = relay.divide(data, const) + res = relay.clip(res, a_min=-128000, a_max=128000) + + mod = tvm.IRModule.from_expr(res) + + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"), + } + output_mod = verify_mixed_precision_output_close( + mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01 + ) + + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 10]), target_precision) + const = relay.cast(relay.const(5, "float32"), target_precision) + res = relay.divide(data, const) + if target_precision == "float16": + res = relay.cast(res, "float32") + res = relay.clip(res, a_min=-128000, a_max=128000) + expected_mod = tvm.IRModule.from_expr(res) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + if __name__ == "__main__": tvm.testing.main()