forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcuda_half_t.h
More file actions
621 lines (569 loc) · 23.1 KB
/
cuda_half_t.h
File metadata and controls
621 lines (569 loc) · 23.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file cuda_half_t.h
* \brief half_t (fp16) definition for cuda codegen.
*/
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#include <sstream>
#include <string>
/*!
 * \brief CUDA source text that defines a software `half` (fp16) class.
 *
 * The literal contains, in order:
 *  - fixed-width integer typedefs and the TVM_FORCE_INLINE / TVM_XINLINE / TVM_ALIGNED macros;
 *  - the TVM_HALF_OPERATOR and TVM_HALF_ASSIGNOP macros, which implement binary and compound
 *    assignment operators by converting both operands to float;
 *  - class `half`, which stores the raw 16 bits in `half_` and converts to and from float with
 *    the bit-manipulation routines `float2half` / `half2float` (each with a volatile twin);
 *  - free operators +, -, *, /, >, <, >=, <= and a `__float2half_rn` wrapper.
 *
 * Everything between R"( and )" is runtime data: any edit there, including to the embedded
 * comments or whitespace, changes the string. Per the file comment this text is meant for
 * CUDA codegen; the consumer is not visible in this header.
 */
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP) \
TVM_XINLINE RTYPE operator OP (half a, half b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (half a, T b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (T a, half b) { \
return RTYPE(float(a) OP float(b)); \
}
#define TVM_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
TVM_XINLINE half operator AOP (const T& a) { \
return *this = half(float(*this) OP float(a)); \
} \
template<typename T> \
TVM_XINLINE half operator AOP (const volatile T& a) volatile { \
return *this = half(float(*this) OP float(a)); \
}
class TVM_ALIGNED(2) half {
public:
uint16_t half_;
static TVM_XINLINE half Binary(uint16_t value) {
half res;
res.half_ = value;
return res;
}
TVM_XINLINE half() {}
TVM_XINLINE half(const float& value) { constructor(value); }
TVM_XINLINE explicit half(const double& value) { constructor(value); }
TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \
return float(half2float(half_)); \
} \
TVM_XINLINE operator float() const volatile { \
return float(half2float(half_)); \
}
TVM_HALF_ASSIGNOP(+=, +)
TVM_HALF_ASSIGNOP(-=, -)
TVM_HALF_ASSIGNOP(*=, *)
TVM_HALF_ASSIGNOP(/=, /)
TVM_XINLINE half operator+() {
return *this;
}
TVM_XINLINE half operator-() {
return half(-float(*this));
}
TVM_XINLINE half operator=(const half& a) {
half_ = a.half_;
return a;
}
template<typename T>
TVM_XINLINE half operator=(const T& a) {
return *this = half(a);
}
TVM_XINLINE half operator=(const half& a) volatile {
half_ = a.half_;
return a;
}
template<typename T>
TVM_XINLINE half operator=(const T& a) volatile {
return *this = half(a);
}
private:
union Bits {
float f;
int32_t si;
uint32_t ui;
};
static int const fp16FractionBits = 10;
static int const fp32FractionBits = 23;
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
static int const shiftSign = 16;
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
static int32_t const signN = 0x80000000; // flt32 sign bit
static int32_t const infC = infN >> shift;
static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
static int32_t const maxC = maxN >> shift;
static int32_t const minC = minN >> shift;
static int32_t const signC = signN >> shiftSign; // flt16 sign bit
static int32_t const mulN = 0x52000000; // (1 << 23) / minN
static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
static int32_t const norC = 0x00400; // min flt32 normal down shifted
static int32_t const maxD = infC - maxC - 1;
static int32_t const minD = minC - subC - 1;
TVM_XINLINE uint16_t float2half(const float& value) const {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position
if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}
// Same as above routine, except for addition of volatile keyword
TVM_XINLINE uint16_t float2half(
const volatile float& value) const volatile {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position
if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}
TVM_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
TVM_XINLINE float half2float(
const volatile uint16_t& value) const volatile {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
template<typename T>
TVM_XINLINE void constructor(const T& value) {
half_ = float2half(float(value));
}
};
TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)
TVM_XINLINE half __float2half_rn(const float a) {
return half(a);
}
)";
/*!
 * \brief CUDA source text with fp16 helper functions.
 *
 * The literal contains:
 *  - `__pack_half2`, which reinterprets two `half` values as 16-bit integers and packs them
 *    into one `unsigned` (x in the low 16 bits, y in the high 16 bits);
 *  - two temporary macros that generate a `half` math function by converting the operands to
 *    float, calling an fp32 routine, and converting the result back;
 *  - under `defined(__CUDA_ARCH__)`: for `__CUDA_ARCH__ >= 530` it instantiates hpow, htan,
 *    hatan and herf, plus hsinh / htanh only when the compiler version is older than 12.8;
 *    otherwise it instantiates only hexp.
 *
 * Both macros are #undef'd at the end of the text. The string is runtime data; do not edit
 * its contents for cosmetic reasons.
 */
static constexpr const char* _cuda_half_util = R"(
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
float tmp_x = __half2float(x); \
float tmp_y = __half2float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2half(result); \
}
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
float tmp_x = __half2float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2half(result); \
}
// Some fp16 math functions are not supported in cuda_fp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 530)
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8)))
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hsinh, sinhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
#endif
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#else
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)
#endif
#endif
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
)";
/*!
 * \brief CUDA source text with bfloat16 helper functions.
 *
 * The literal contains:
 *  - `__pack_nv_bfloat162`, which reinterprets two `nv_bfloat16` values as 16-bit integers and
 *    packs them into one `unsigned` (x in the low 16 bits, y in the high 16 bits);
 *  - two temporary macros that generate an `nv_bfloat16` math function by converting the
 *    operands to float, calling an fp32 routine, and converting the result back;
 *  - instantiations of hpow, htan, hatan and herf, plus hsinh / htanh only when the compiler
 *    version is older than 12.8.
 *
 * Unlike the fp16 text, the instantiations here are not wrapped in a `__CUDA_ARCH__` check.
 * Both macros are #undef'd at the end of the text. The string is runtime data; do not edit
 * its contents for cosmetic reasons.
 */
static constexpr const char* _cuda_bfloat16_util = R"(
// Pack two bfloat16 values.
static inline __device__ __host__ unsigned
__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// Some bfp16 math functions are not supported in cuda_bfp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x, nv_bfloat16 y) { \
float tmp_x = __bfloat162float(x); \
float tmp_y = __bfloat162float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2bfloat16(result); \
}
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ nv_bfloat16 HALF_MATH_NAME(nv_bfloat16 x) { \
float tmp_x = __bfloat162float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2bfloat16(result); \
}
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 8)))
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hsinh, sinhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
#endif
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
)";
/*!
 * \brief CUDA source text providing fallbacks for the `*_sync` warp shuffle intrinsics.
 *
 * When `__CUDA_ARCH__` is defined and below 700, the text defines `__shfl_sync`,
 * `__shfl_down_sync` and `__shfl_up_sync` as macros that forward to `__shfl`, `__shfl_down`
 * and `__shfl_up`. The `mask` argument is accepted but not forwarded. All four macro
 * parameters are required, so callers must pass `width` explicitly.
 */
static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))
#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))
#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif
)";
/*!
 * \brief Append CUDA source for 4-wide half/bfloat16 vector types and fp8/fp4 packing helpers.
 *
 * Text is appended to \p stream in a fixed order; each section is written only when the
 * corresponding flag(s) are set:
 *  - enable_fp16 || enable_bf16: the `half4_bfloat164<T, TVec2>` struct template. Its body
 *    gains fp8x4 conversion members when enable_fp8 is set and fp4x4 conversion members when
 *    enable_fp4 is set.
 *  - enable_fp16: the `half4` alias and `make_half4`.
 *  - enable_bf16: the `nv_bfloat164` alias, `make_nv_bfloat164` and `make_nv_bfloat162`, plus
 *    `cast_to_nv_bfloat162` overloads for the fp8x2 types when enable_fp8 is also set.
 *  - enable_fp8: `make___nv_fp8x2_*` / `make___nv_fp8x4_*` for e5m2, e4m3 and e8m0.
 *  - enable_fp4: `make___nv_fp4x2_e2m1`, `make___nv_fp4x4_e2m1` and a
 *    `cast_to_nv_bfloat162` overload for `__nv_fp4x2_e2m1`.
 *
 * With every flag false nothing is written.
 *
 * The function is declared `inline` because it is defined in a header: without it, including
 * this header from more than one translation unit produces duplicate definitions at link time.
 *
 * The contents of every raw string literal below are runtime data and are emitted verbatim.
 *
 * \param stream Output stream the CUDA source text is appended to.
 * \param enable_fp16 Emit the fp16 (`__half`) vector alias and constructor helper.
 * \param enable_bf16 Emit the bfloat16 vector alias and helpers.
 * \param enable_fp8 Emit fp8 conversion members and packing helpers.
 * \param enable_fp4 Emit fp4 conversion members and packing helpers.
 */
inline void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16,
                                           bool enable_bf16, bool enable_fp8, bool enable_fp4) {
  if (enable_fp16 || enable_bf16) {
    // Open the shared struct template; it is closed by the "};" chunk further down so that
    // the optional fp8/fp4 members can be spliced into the struct body.
    stream << R"(
#include <type_traits>
template <typename T, typename TVec2>
struct __align__(8) half4_bfloat164 {
T x, y, z, w;
__host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0)) {}
__host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
)";
    if (enable_fp8) {
      // NOTE(review): only the e4m3 constructor has an `if constexpr` branch for T != __half;
      // the e5m2 and e8m0 constructors always go through static_cast<TVec2>. Confirm that is
      // intended for the bfloat16 instantiation.
      stream << R"(
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
if constexpr (std::is_same_v<T, __half>) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
TVec2 lo_half2 = static_cast<TVec2>(lo_part);
TVec2 hi_half2 = static_cast<TVec2>(hi_part);
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
} else {
__nv_fp8_storage_t elem0_raw = static_cast<__nv_fp8_storage_t>(fp8x4.__x & 0xFF);
__nv_fp8_storage_t elem1_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 8) & 0xFF);
__nv_fp8_storage_t elem2_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 16) & 0xFF);
__nv_fp8_storage_t elem3_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 24) & 0xFF);
__nv_fp8_e4m3 elem0, elem1, elem2, elem3;
elem0.__x = elem0_raw;
elem1.__x = elem1_raw;
elem2.__x = elem2_raw;
elem3.__x = elem3_raw;
x = T(elem0);
y = T(elem1);
z = T(elem2);
w = T(elem3);
}
}
__host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
__nv_fp8x4_e4m3 result;
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) {
__nv_fp8x2_e5m2 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
TVec2 lo_half2 = static_cast<TVec2>(lo_part);
TVec2 hi_half2 = static_cast<TVec2>(hi_part);
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
__nv_fp8x4_e5m2 result;
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e8m0& fp8x4) {
__nv_fp8x2_e8m0 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
TVec2 lo_half2 = static_cast<TVec2>(lo_part);
TVec2 hi_half2 = static_cast<TVec2>(hi_part);
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e8m0() const {
__nv_fp8x4_e8m0 result;
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e8m0 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
)";
    }
    if (enable_fp4) {
      stream << R"(
__host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
if constexpr (std::is_same_v<T, __half>) {
__nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
__nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
TVec2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
x = reinterpret_cast<T*>(&lo_half2)[0];
y = reinterpret_cast<T*>(&lo_half2)[1];
z = reinterpret_cast<T*>(&hi_half2)[0];
w = reinterpret_cast<T*>(&hi_half2)[1];
} else {
__nv_fp4_e2m1 elem0, elem1, elem2, elem3;
elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x4.__x & 0xF);
elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 4) & 0xF);
elem2.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 8) & 0xF);
elem3.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 12) & 0xF);
x = T(elem0);
y = T(elem1);
z = T(elem2);
w = T(elem3);
}
}
__host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
return __nv_fp4x4_e2m1(lo_half2, hi_half2);
}
)";
    }
    // Close the struct opened above.
    stream << R"(
};
)";
  }
  if (enable_fp16) {
    stream << R"(
using half4 = half4_bfloat164<__half, __half2>;
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
  }
  if (enable_bf16) {
    stream << R"(
using nv_bfloat164 = half4_bfloat164<nv_bfloat16, nv_bfloat162>;
__host__ __device__ nv_bfloat164 make_nv_bfloat164(nv_bfloat16 x, nv_bfloat16 y, nv_bfloat16 z, nv_bfloat16 w) {
return nv_bfloat164(x, y, z, w);
}
__host__ __device__ nv_bfloat162 make_nv_bfloat162(nv_bfloat16 x, nv_bfloat16 y) {
return nv_bfloat162(x, y);
}
)";
    if (enable_fp8) {
      stream << R"(
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8x2) {
__nv_fp8_e4m3 elem0, elem1;
elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e5m2& fp8x2) {
__nv_fp8_e5m2 elem0, elem1;
elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e8m0& fp8x2) {
__nv_fp8_e8m0 elem0, elem1;
elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
)";
    }
  }
  if (enable_fp8) {
    stream << R"(
__device__ __nv_fp8x2_e5m2 make___nv_fp8x2_e5m2(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) {
__nv_fp8x2_e5m2 result;
result.__x = (x.__x) | (y.__x << 8);
return result;
}
__device__ __nv_fp8x4_e5m2 make___nv_fp8x4_e5m2(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b, __nv_fp8_e5m2 c, __nv_fp8_e5m2 d) {
__nv_fp8x4_e5m2 result;
result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
return result;
}
__device__ __nv_fp8x2_e4m3 make___nv_fp8x2_e4m3(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) {
__nv_fp8x2_e4m3 result;
result.__x = (x.__x) | (y.__x << 8);
return result;
}
__device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b, __nv_fp8_e4m3 c, __nv_fp8_e4m3 d) {
__nv_fp8x4_e4m3 result;
result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
return result;
}
__device__ __nv_fp8x2_e8m0 make___nv_fp8x2_e8m0(__nv_fp8_e8m0 x, __nv_fp8_e8m0 y) {
__nv_fp8x2_e8m0 result;
result.__x = (x.__x) | (y.__x << 8);
return result;
}
__device__ __nv_fp8x4_e8m0 make___nv_fp8x4_e8m0(__nv_fp8_e8m0 a, __nv_fp8_e8m0 b, __nv_fp8_e8m0 c, __nv_fp8_e8m0 d) {
__nv_fp8x4_e8m0 result;
result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
return result;
}
)";
  }
  if (enable_fp4) {
    // NOTE(review): the cast_to_nv_bfloat162 overload below refers to nv_bfloat16 types but is
    // emitted whenever enable_fp4 is set, regardless of enable_bf16. Confirm the bfloat16
    // declarations are always available in that configuration.
    stream << R"(
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
__nv_fp4x2_e2m1 result;
result.__x = (x.__x) | (y.__x << 4);
return result;
}
__device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b, __nv_fp4_e2m1 c, __nv_fp4_e2m1 d) {
__nv_fp4x4_e2m1 result;
result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | (static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | (static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | (static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
return result;
}
__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp4x2_e2m1& fp4x2) {
__nv_fp4_e2m1 elem0, elem1;
elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x2.__x & 0xF);
elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x2.__x >> 4) & 0xF);
nv_bfloat16 x = nv_bfloat16(elem0);
nv_bfloat16 y = nv_bfloat16(elem1);
return nv_bfloat162(x, y);
}
)";
  }
}
#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_