-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathreduction_all_max_naive_opt_flexible.ptx
More file actions
156 lines (134 loc) · 4.9 KB
/
reduction_all_max_naive_opt_flexible.ptx
File metadata and controls
156 lines (134 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// PTX module header.
// NOTE(review): the kernel below uses shfl.sync, which needs PTX ISA >= 6.0
// and sm_30 or newer; .version 7.0 / .target sm_50 satisfy both.
.version 7.0
.target sm_50 // minimum target architecture; enough for my Titan X
.address_size 64 // 64-bit addressing; the kernel's pointer params are .u64
// reductionAllMaxNaiveOptFlexible: per-CTA maximum of an f32 input range.
//
// Variant of reduction_all_max_naive_opt.ptx that runs on a 3-D block in
// which only the tid.z == 0 slice does any work. The intended launch shape
// is (32, N, 1024/(32*N)): the N warps with tid.z == 0 perform the
// reduction, while every other warp waits on barrier 1 at the end of the
// kernel. The resident thread count therefore stays fixed (nothing else can
// be co-scheduled into those slots) while the number of working warps
// varies, which makes it possible to experiment with different layouts.
//
// Parameters:
//   ptrIn     - global f32 input. CTA c reads numBlocks*ntid.x*ntid.y
//               floats starting at float index c*numBlocks*ntid.x*ntid.y.
//   ptrOut    - global f32 output. CTA c writes its maximum to ptrOut[c].
//   numBlocks - number of (ntid.x*ntid.y)-float chunks handled per CTA.
//
// Assumptions (not checked here):
//   - ntid.x == 32. tid.x is used as the lane index, and the per-thread
//     offset hard-codes 32 threads per tid.y row.
//   - ntid.y <= 32, the size of the shared per-warp table.
//   - numBlocks is a positive multiple of 8. The loop body always runs at
//     least once and consumes 8 chunks per iteration.
//   - ptrIn is 16-byte aligned, as required by the v4 loads.
.visible .entry reductionAllMaxNaiveOptFlexible (
    .param .u64 ptrIn,
    .param .u64 ptrOut,
    .param .u64 numBlocks
) {
    .reg .pred  %pIdle;
    .reg .pred  %pMore;
    .reg .pred  %pLane0;
    .reg .pred  %pOtherWarp;
    .reg .pred  %pHasSlot;
    .reg .u32   %tx, %ty, %tz, %nx, %ny, %cta, %activeThreads;
    .reg .u32   %smemBase, %smemSlot;
    .reg .u64   %src, %dst, %chunks, %chunkSize, %step, %done;
    .reg .u64   %tx64, %ty64, %linearTid, %threadBytes, %ctaOut, %ctaBytes;
    .reg .f32   %best, %peer;
    .reg .f32   %a<4>;
    .reg .f32   %b<4>;
    .shared .align 4 .f32 warpMaxima[32];

    // --- Kernel arguments -------------------------------------------------
    ld.param.u64    %src, [ptrIn];
    ld.param.u64    %dst, [ptrOut];
    ld.param.u64    %chunks, [numBlocks];

    // Only the z == 0 slice works; everyone else goes straight to the
    // trailing barrier.
    mov.u32         %tz, %tid.z;
    setp.ne.u32     %pIdle, %tz, 0;
@%pIdle bra         idle_wait;

    // chunkSize = ntid.x * ntid.y (threads in one z-slice; ntid.z ignored).
    mov.u32         %nx, %ntid.x;
    mov.u32         %ny, %ntid.y;
    mul.wide.u32    %chunkSize, %nx, %ny;

    // Output slot: ptrOut + 4*ctaid.x bytes.
    // Input range:  ptrIn + 4*ctaid.x*chunkSize*numBlocks bytes.
    mov.u32         %cta, %ctaid.x;
    mul.wide.u32    %ctaOut, %cta, 4;
    add.u64         %dst, %dst, %ctaOut;
    mul.lo.u64      %ctaBytes, %chunkSize, %chunks;
    mul.lo.u64      %ctaBytes, %ctaBytes, %ctaOut;
    add.u64         %src, %src, %ctaBytes;

    // Each thread starts at its own 16-byte (4-float) slot. The thread is
    // linearised as tid.x + 32*tid.y (relies on ntid.x == 32).
    mov.u32         %tx, %tid.x;
    mov.u32         %ty, %tid.y;
    cvt.u64.u32     %tx64, %tx;
    cvt.u64.u32     %ty64, %ty;
    mad.lo.u64      %linearTid, %ty64, 32, %tx64;
    shl.b64         %threadBytes, %linearTid, 4;
    add.u64         %src, %src, %threadBytes;

    // Seed the running max with this thread's first element. The first
    // loop iteration reads it again, which is harmless for max.
    ld.global.f32   %best, [%src];

    // Successive vector loads of one thread are chunkSize*16 bytes apart.
    shl.b64         %step, %chunkSize, 4;
    mov.u64         %done, 0;

chunk_loop:
    // Two v4 loads per thread cover 8 chunks of the CTA's range.
    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%src];
    add.u64         %src, %src, %step;
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%src];
    add.u64         %src, %src, %step;
    // Fold components in .w, .x, .y, .z order.
    max.f32         %best, %best, %a3;
    max.f32         %best, %best, %a0;
    max.f32         %best, %best, %a1;
    max.f32         %best, %best, %a2;
    max.f32         %best, %best, %b3;
    max.f32         %best, %best, %b0;
    max.f32         %best, %best, %b1;
    max.f32         %best, %best, %b2;
    add.u64         %done, %done, 8;
    setp.lt.u64     %pMore, %done, %chunks;
@%pMore bra         chunk_loop;

    // Butterfly (xor) shuffles. After 5 rounds every lane of the warp
    // holds the warp-wide maximum.
    shfl.sync.bfly.b32 %peer, %best, 1, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 2, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 4, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 8, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 16, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;

    // Lane 0 of warp tid.y publishes its result to warpMaxima[tid.y].
    mov.u32         %smemBase, warpMaxima;
    mad.lo.u32      %smemSlot, %ty, 4, %smemBase;
    setp.eq.u32     %pLane0, %tx, 0;
@%pLane0 st.shared.f32 [%smemSlot], %best;

    // Named barrier 0 counts only the ntid.x*ntid.y working threads.
    mul.lo.u32      %activeThreads, %nx, %ny;
    bar.sync        0, %activeThreads;

    // Only warp tid.y == 0 goes on to the cross-warp reduction.
    setp.ne.u32     %pOtherWarp, %ty, 0;
@%pOtherWarp bra    idle_wait;

    // Lanes with tid.x < ntid.y each pick up one published warp maximum.
    // The remaining lanes keep warp 0's own maximum, which cannot change
    // the final result.
    setp.lt.u32     %pHasSlot, %tx, %ny;
    mad.lo.u32      %smemSlot, %tx, 4, %smemBase;
@%pHasSlot ld.shared.f32 %best, [%smemSlot];

    shfl.sync.bfly.b32 %peer, %best, 1, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 2, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 4, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 8, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;
    shfl.sync.bfly.b32 %peer, %best, 16, 0x1f, 0xffffffff;
    max.f32         %best, %best, %peer;

    // Lane 0 of warp 0 writes the CTA's maximum to ptrOut[ctaid.x].
@%pLane0 st.global.f32 [%dst], %best;

idle_wait:
    // Every thread of the CTA, idle or working, meets here. The idle warps
    // therefore stay resident until the reduction has finished and keep the
    // SM from scheduling any other blocks.
    bar.sync        1;
    ret;
}