diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index bf3e83928ed7..27d44d9f7f4a 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -410,7 +410,35 @@ struct __align__(8) half4 {
     result.__x =
         (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
-  })";
+  }
+  // Unpack four e5m2 fp8 values into x, y, z, w: the low and high 16 bits of
+  // fp8x4.__x are each converted as one fp8x2 pair.
+  __host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
+    __nv_fp8x2_e5m2 lo_part, hi_part;
+    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+    __half2 lo_half2 = static_cast<__half2>(lo_part);
+    __half2 hi_half2 = static_cast<__half2>(hi_part);
+    x = reinterpret_cast<__half*>(&lo_half2)[0];
+    y = reinterpret_cast<__half*>(&lo_half2)[1];
+    z = reinterpret_cast<__half*>(&hi_half2)[0];
+    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  }
+  // Pack x, y, z, w into one e5m2 fp8x4: (x, y) fill the low 16 bits of the
+  // result and (z, w) the high 16 bits.
+  __host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
+    __nv_fp8x4_e5m2 result;
+    // View each adjacent pair of members as a __half2; the pointee must be
+    // const because this operator is const. Assumes x, y and z, w are
+    // contiguous __half members -- confirm against the struct definition.
+    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
+    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
+    result.__x =
+        (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
+    return result;
+  }
+  )";
   }
   stream << R"(
 };