Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions csrc/fp_quantizer/fp_quantize_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,20 @@ __global__ void apply_quantization(T* val,
}

template <typename T,
int q_mantisa_bits,
int mantisa_bits,
int total_q_bits = 16,
int _mantisa_bits = 3,
int _exponent_bits = 4>
int q_mantisa_bits = 3,
int q_exponent_bits = 4>
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
{
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
const uint32_t g_index = (tidx / group_size);
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
const uint8_t* load_base_ptr =
Expand Down Expand Up @@ -298,17 +298,17 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
}

uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits;
uint16_t dst_mantisa = (new_data & _mantisa_mask);

if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;

q_buf[j] =
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) |
(dst_mantisa << (mantisa_bits - q_mantisa_bits)));
float up_cast = conversion::to<float>(store_buf[j]);
store_buf[j] = conversion::to<T>(up_cast * scale);
}
Expand Down Expand Up @@ -395,10 +395,10 @@ INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);

template <typename T,
int q_mantisa_bits,
int mantisa_bits,
int total_q_bits = 16,
int _mantisa_bits = 3,
int _exponent_bits = 4>
int q_mantisa_bits = 3,
int q_exponent_bits = 4>
__global__ void apply_selective_dequantization(uint8_t* val,
T* q_val,
int32_t* indexes,
Expand All @@ -409,11 +409,11 @@ __global__ void apply_selective_dequantization(uint8_t* val,
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
int input_index = index * total_num_elements + tidx;
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (_mantisa_bits + _exponent_bits);
constexpr int quantized_bits = q_mantisa_bits + q_exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << q_mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << q_exponent_bits) - 1) << q_mantisa_bits;
constexpr uint16_t _sign_mask = 1U << (q_mantisa_bits + q_exponent_bits);
const uint32_t g_index = (input_index / group_size);
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
const uint8_t* load_base_ptr =
Expand Down Expand Up @@ -476,17 +476,17 @@ __global__ void apply_selective_dequantization(uint8_t* val,
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
}

uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
uint16_t sign = (new_data & _sign_mask) >> (q_mantisa_bits + q_exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> q_mantisa_bits;
uint16_t dst_mantisa = (new_data & _mantisa_mask);

if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
dst_exponent = (dst_exponent - ((1 << (q_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;

q_buf[j] =
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
((sign << (q_exponent_bits + mantisa_bits)) | (dst_exponent << mantisa_bits) |
(dst_mantisa << (mantisa_bits - q_mantisa_bits)));
float up_cast = conversion::to<float>(store_buf[j]);
store_buf[j] = conversion::to<T>(up_cast * scale);
}
Expand Down
Loading