From 9a8d81297056621aeada891e580fd3c82d9d74b3 Mon Sep 17 00:00:00 2001 From: Johnson Zhang Date: Wed, 5 Jul 2023 13:09:49 +0800 Subject: [PATCH] [QNN] Support Dequantize to "float16" and Quantize to "uint16" --- include/tvm/relay/qnn/attrs.h | 2 ++ python/tvm/relay/qnn/op/qnn.py | 21 +++++++----- src/relay/qnn/op/dequantize.cc | 28 ++++++++++++---- src/relay/qnn/op/quantize.cc | 5 +-- src/relay/qnn/utils.h | 3 +- tests/python/relay/test_op_qnn_dequantize.py | 35 ++++++++++++++++++-- tests/python/relay/test_op_qnn_quantize.py | 23 +++++++++++++ 7 files changed, 95 insertions(+), 22 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 64b2dc20981d..85e008528625 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -95,9 +95,11 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { /*! \brief Attribute for dequantize operator */ struct DequantizeAttrs : public tvm::AttrsNode { + DataType out_dtype; int axis; TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") { + TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [float16, float32]."); TVM_ATTR_FIELD(axis) .describe( "The channel axis for channel wise dequantization. Default value is -1," diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index eb64b56e829d..968935f062b2 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -186,8 +186,8 @@ def requantize( def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): r"""Quantize op - This operator takes float32 as input and produces quantized int8 or unit8 as output. - The input tensor can be of any shape. The output shape is the same as input shape. + This operator takes float32 input and produces quantized output. The input + tensor can be of any shape. The output shape is the same as input shape. Q_output = clamp((round(input_tensor/output_scale) + output_zero_point), out_dtype::min, @@ -206,8 +206,9 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. + out_dtype : str, optional - The data type of the input tensor. Can be [int8, uint8, int32] + The data type of the output tensor. Can be [int8, unit8, int16, uint16, int32]. Returns ------- @@ -256,16 +257,15 @@ def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype return _make.simulated_quantize(data, out_dtype, output_scale, output_zero_point, axis) -def dequantize(data, input_scale, input_zero_point, axis=-1): +def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"): r"""Dequantize op - This operator takes quantized int8 and unit8 as input and produces - dequantized float32 as output. The output shape is the same as input shape. The input - tensor can be of any shape. + This operator takes quantized input and produces dequantized float output. + The output shape is the same as input shape. The input tensor can be of any shape. Parameters ---------- data : tvm.relay.Expr - The input tensor to be dequantized. Can be of type [int8, uint8, int32]. + The input tensor to be dequantized. Can be of type [int8, unit8, int16, uint16, int32]. input_scale : tvm.relay.Expr The input scale. @@ -276,13 +276,16 @@ def dequantize(data, input_scale, input_zero_point, axis=-1): axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. + out_dtype : str, optional + The data type of the output tensor. Can be [float16, float32]. + Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.dequantize(data, input_scale, input_zero_point, axis) + return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype) def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype="int8"): diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 1ddcde81234d..5e2ef39edacb 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -47,9 +47,10 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto input_dtype = data->dtype; ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) || - input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32)) - << "Input type should be one of the quantized types [unit8, int8, int16, int32] but was " - << input_dtype; + input_dtype == DataType::Int(16) || input_dtype == DataType::UInt(16) || + input_dtype == DataType::Int(32)) + << "Input type should be one of the quantized types [int8, unit8, int16, uint16, int32] but " + << "was " << input_dtype; const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; @@ -77,18 +78,24 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Check and assign types for scale and zero points. AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point + const Array oshape = data->shape; - // assign output type, output will always be float 32. - reporter->Assign(types[3], TensorType(oshape, DataType::Float(32))); + const DataType out_dtype = dequantize_attrs->out_dtype; + ICHECK(out_dtype == DataType::Float(16) || out_dtype == DataType::Float(32)) + << "Output type should be one of [float16, float32] but was " << out_dtype; + // assign output type. + reporter->Assign(types[3], TensorType(oshape, out_dtype)); return true; } -Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) { +Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis, + DataType out_dtype) { // real_value = scale * (quantized_value - zero_point) // A more detailed explanation can be found here - // https://github.com/google/gemmlowp/blob/master/doc/quantization.md auto attrs = make_object(); attrs->axis = axis; + attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dequantize"); return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {}); } @@ -125,7 +132,14 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point); auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale); - return scaled_output; + + const DataType out_dtype = attrs->out_dtype; + if (out_dtype.is_float() && out_dtype.bits() == 32) return scaled_output; + + double min_val = tvm::min_value(out_dtype).as()->value; + double max_val = tvm::max_value(out_dtype).as()->value; + auto clamped_output = Clip(scaled_output, min_val, max_val); + return Cast(clamped_output, out_dtype); } Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 1a16705932d0..8ed1f9ef4c4f 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -91,8 +91,9 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || - out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32)) - << "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype; + out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) || + out_dtype == DataType::Int(32)) + << "Output type should be one of [int8, unit8, int16, uint16, int32] but was " << out_dtype; // assign output type reporter->Assign(types[3], TensorType(oshape, out_dtype)); return true; diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index 5005d6068524..4102fb29a6fe 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -135,7 +135,8 @@ static inline Expr Dequantize(const Expr& data, const Expr& input_scale, return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->()); } -Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis); +Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis, + DataType out_dtype = DataType::Float(32)); Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, const Expr& output_zero_point, const Array& types, diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index b332bd94f31e..3b2ae97eb63f 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -23,13 +23,19 @@ from tvm.relay.testing import run_infer_type -def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis): +def dequantize_test_driver( + in_dtype, quant_args, in_data, verify_output_data, axis, out_dtype="float32" +): shape = in_data.shape input_data = relay.var("input_data", shape=shape, dtype=in_dtype) input_zero_point = relay.const(quant_args["in_zero_point"], "int32") input_scale = relay.const(quant_args["in_scale"], "float32") quantized_output = relay.qnn.op.dequantize( - input_data, input_scale=input_scale, input_zero_point=input_zero_point, axis=axis + input_data, + input_scale=input_scale, + input_zero_point=input_zero_point, + axis=axis, + out_dtype=out_dtype, ) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = tvm.IRModule.from_expr(mod) @@ -41,7 +47,7 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, ax rt_mod.run() res = rt_mod.get_output(0).numpy() np.testing.assert_equal(res, verify_output_data) - assert res.dtype == np.float32 + assert res.dtype == out_dtype def test_uint8_to_float32(): @@ -74,6 +80,28 @@ def test_int8_to_float32(): ) +def test_int8_to_float16(): + data = ( + np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) + .astype("int8") + .reshape((2, 5)) + ) + output = ( + np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) + .astype("float16") + .reshape((2, 5)) + ) + quant_args = {"in_zero_point": -1, "in_scale": 0.5} + dequantize_test_driver( + in_dtype="int8", + quant_args=quant_args, + in_data=data, + verify_output_data=output, + axis=-1, + out_dtype="float16", + ) + + def test_scalar_int8_to_float32(): data = np.array(-128).astype("int8") output = np.array(-63.5).astype("float32") @@ -171,6 +199,7 @@ def test_dynamic_dequantize(): if __name__ == "__main__": test_uint8_to_float32() test_int8_to_float32() + test_int8_to_float16() test_scalar_int8_to_float32() test_int32_to_float32() test_channelwise_axis_1() diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 322382ca002c..3a3521b11e90 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -88,6 +88,28 @@ def test_float32_to_int8(): ) +def test_float32_to_uint16(): + data = ( + np.array([-6553, -6552.8, -6552.6, -6552.4, -6552.2, 6553.2, 6553.4, 6553.6, 6553.8, 6554]) + .astype("float32") + .reshape((2, 5)) + ) + output = ( + np.array([0, 1, 2, 3, 4, 65531, 65532, 65533, 65534, 65535]) + .astype("uint16") + .reshape((2, 5)) + ) + quant_args = {"out_zero_point": np.int32(32765), "out_scale": np.float32(0.2)} + quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=-1, + out_dtype="uint16", + in_data=data, + verify_output_data=output, + ) + + def test_scalar_float32_to_int8(): data = np.array(-63.5).astype("float32") output = np.array(-128).astype("int8") @@ -177,6 +199,7 @@ def test_dynamic_quantize(): if __name__ == "__main__": test_float32_to_uint8() test_float32_to_int8() + test_float32_to_uint16() test_scalar_float32_to_int8() test_channelwise_axis_0() test_channelwise_axis_1()