|
| 1 | +// Copyright (c) Microsoft Corporation. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +// DeepSpeed Team |
| 5 | + |
| 6 | +// NOTE: |
| 7 | +// This shared-memory implementation targets AArch64 CPUs. |
| 8 | +// Minimum supported architecture is ARMv8-A with NEON (Advanced SIMD) support. |
| 9 | +// Systems without NEON are not supported. |
| 10 | + |
| 11 | +#include <arm_neon.h> |
| 12 | +#include <stddef.h> |
| 13 | +#include <stdint.h> |
| 14 | +#include <cmath> |
| 15 | + |
// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements.
// constexpr (was a mutable `static int`): this is a fixed architectural
// constant of NEON's 128-bit Q registers and must never change at runtime;
// also keeps it consistent with the sibling constant below.
static constexpr int vector_length_in_bytes = 16;
// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register.
// Using 8 would require two 128-bit registers, so limit to 4.
static constexpr int full_precision_elements_in_fixed_vector = 4;
| 21 | + |
// Widen 4 bf16 values (given as raw uint16 bit patterns) to 4 fp32 values.
// bf16 shares fp32's sign and exponent layout; the 16 dropped low-order
// mantissa bits are restored as zeros, so this conversion is exact.
static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input)
{
    // VSHLL #16: zero-extend each 16-bit lane to 32 bits while shifting it
    // into the high half -- exactly where fp32 expects the bf16 bits.
    return vreinterpretq_f32_u32(vshll_n_u16(input, 16));
}
| 29 | + |
// Widen 4 FP16 lanes to 4 FP32 lanes (exact: every fp16 value is
// representable in fp32, so no rounding occurs).
static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input)
{
    const float32x4_t widened = vcvt_f32_f16(input);
    return widened;
}
| 35 | + |
// Converts 4 float32 -> 4 bfloat16 with round-to-nearest-even (RNE) and NaN
// handling. Before truncating the low 16 mantissa bits, each value is rounded
// to the nearest representable bf16, ties to even.
static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src)
{
    // Reinterpret float32 bits as uint32 so rounding can be done with
    // integer arithmetic on the raw encoding.
    uint32x4_t u32 = vreinterpretq_u32_f32(src);

    const uint32x4_t ones = vdupq_n_u32(0x1);
    const uint32x4_t vec_bias =
        vdupq_n_u32(0x7FFF);  // one less than half of the dropped bits range
    const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF);

    // RNE: lsb = (input >> 16) & 1 -- the bit that becomes the bf16 LSB.
    uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones);

    // rounding_bias = 0x7FFF + lsb; adding it rounds halfway cases toward
    // the value whose LSB is 0 (ties-to-even).
    uint32x4_t bias = vaddq_u32(vec_bias, lsb);

    // input += rounding_bias. A carry into the exponent is the correct
    // rounding result, but for NaN inputs it can corrupt the payload
    // (potentially producing Inf), hence the NaN select below.
    u32 = vaddq_u32(u32, bias);

    // Narrowing shift right by 16 keeps the upper 16 bits of each lane:
    // vshrq_n_u32 keeps 32-bit width, vshrn_n_u32 narrows to 16-bit lanes.
    uint16x4_t bf16 = vshrn_n_u32(u32, 16);

    // NaN mask: ~(src == src) is all-ones exactly in lanes holding NaN
    // (vmvnq_u32 is bitwise NOT; a normal number compares equal to itself).
    uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src));

    // BUG FIX: narrow each 32-bit mask lane to a 16-bit mask lane so mask
    // lane i corresponds to bf16 lane i. The previous
    // vreinterpret_u16_u32(vget_low_u32(isnan)) only looked at mask lanes
    // 0-1 and split their bits across four 16-bit lanes: a NaN in lane 0
    // clobbered bf16 lanes 0 AND 1, and NaNs in lanes 2-3 were missed.
    uint16x4_t mask = vmovn_u32(isnan);
    return vbsl_u16(mask, nan_bf16, bf16);
}
| 71 | + |
// fp32 and fp16 are both IEEE formats. Narrowing fp32 -> fp16 is handled by
// vcvt_f16_f32 itself, which rounds to nearest rather than arbitrarily
// truncating the low-order bits.
static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input)
{
    // Narrow 4 FP32 lanes to 4 FP16 lanes with rounding.
    const float16x4_t narrowed = vcvt_f16_f32(input);
    return narrowed;
}
| 80 | + |
// Reduce functions down below use vectorized algorithm, the number of bytes processed each
// iteration depends on vector length. 128bit vector ==> 16 bytes. sticking to NEON 128 bit

// Per-dtype reduction over the half-open element range
// [start_elements, start_elements + num_elements), reading from the rank
// buffers in `buffers` and writing the result into `to_buffer`. Buffers are
// untyped char* and cast internally per dtype. NOTE(review): presumably an
// element-wise sum across ranks (all-reduce) -- confirm against the
// definitions, which are not visible in this file.
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);

// Copies n_bytes from `from` to `to`. NOTE(review): the name suggests the
// copy is parallelized in the implementation (definition not visible here);
// verify thread-safety expectations against the defining translation unit.
void parallel_memcpy(void* to, void* from, size_t n_bytes);
| 89 | + |
// --- Load macros: raw-pointer wrappers around NEON vld1 variants ---
// U8 loads a full 128-bit Q register (16 bytes); U16/F16 load 64-bit D
// registers (4 lanes); F32 loads a 128-bit Q register (4 lanes).
#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X))
#define VLOAD_U16(X) vld1_u16((uint16_t*)(X))
#define VLOAD_F16(X) vld1_f16((float16_t*)(X))
#define VLOAD_F32(X) vld1q_f32((float32_t*)(X))

// --- Store macros: raw-pointer wrappers around NEON vst1 variants ---
#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B)
#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B)
#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B)  // fp16 supported from armv8.2-a+fp16
#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B)

// --- Arithmetic ---
// Both map to the same 128-bit add on NEON; the _2VL name exists so callers
// written against a wider-vector (two-vector-length) scheme keep compiling.
#define VADD_F32(A, B) vaddq_f32(A, B)
#define VADD_F32_2VL(A, B) vaddq_f32(A, B)

// --- Conversions: macro aliases over the inline helpers defined above ---
#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)