diff --git a/csrc/fp_quantizer/fp_quantize_impl.cu b/csrc/fp_quantizer/fp_quantize_impl.cu index a71fe49c5eed..8b1913e1588f 100644 --- a/csrc/fp_quantizer/fp_quantize_impl.cu +++ b/csrc/fp_quantizer/fp_quantize_impl.cu @@ -221,20 +221,20 @@ __global__ void apply_quantization(T* val, } template + int q_mantisa_bits = 3, + int q_exponent_bits = 4> __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements) { constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size; - constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; - constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; - constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; - constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; - constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits); + constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits; + constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits); const uint32_t g_index = (tidx / group_size); const uint32_t group_size_bytes = (group_size * quantized_bits / 8); const uint8_t* load_base_ptr = @@ -298,17 +298,17 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); } - uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); - uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits; uint16_t dst_mantisa = (new_data & _mantisa_mask); if (dst_exponent != (1 << q_exponent_bits) - 1) - dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) + (1 << (q_exponent_bits - 1)) - 1; q_buf[j] = - ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | - (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + ((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) | + (dst_mantisa << (mantisa_bits - q_mantisa_bits))); float up_cast = conversion::to(store_buf[j]); store_buf[j] = conversion::to(up_cast * scale); } @@ -395,10 +395,10 @@ INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); template + int q_mantisa_bits = 3, + int q_exponent_bits = 4> __global__ void apply_selective_dequantization(uint8_t* val, T* q_val, int32_t* indexes, @@ -409,11 +409,11 @@ __global__ void apply_selective_dequantization(uint8_t* val, constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size; int input_index = index * total_num_elements + tidx; - constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; - constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; - constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; - constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; - constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits); + constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits; + constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits); const uint32_t g_index = (input_index / group_size); const uint32_t group_size_bytes = (group_size * quantized_bits / 8); const uint8_t* load_base_ptr = @@ -476,17 +476,17 @@ __global__ void apply_selective_dequantization(uint8_t* val, new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); } - uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); - uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits; uint16_t dst_mantisa = (new_data & _mantisa_mask); if (dst_exponent != (1 << q_exponent_bits) - 1) - dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) + (1 << (q_exponent_bits - 1)) - 1; q_buf[j] = - ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | - (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + ((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) | + (dst_mantisa << (mantisa_bits - q_mantisa_bits))); float up_cast = conversion::to(store_buf[j]); store_buf[j] = conversion::to(up_cast * scale); }