# kernels.py (forked from InfiniTensor/go-llama-go)
import torch
import triton
import triton.language as tl


@triton.jit
def _rms_norm_kernel(
    X_ptr, W_ptr, Out_ptr,
    stride_x_row, stride_y_row,
    N_COLS, eps,
    BLOCK_SIZE: tl.constexpr
):
    # Current row index
    row_idx = tl.program_id(0)
    # Calculate start pointers for the current row
    row_start_ptr = X_ptr + row_idx * stride_x_row
    out_row_start_ptr = Out_ptr + row_idx * stride_y_row
    # Generate column offsets
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N_COLS
    # Load data and weights; use float32 for precision
    x_val = tl.load(row_start_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    w_val = tl.load(W_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    # RMSNorm computation:
    # 1. Mean of squares
    x_sq = x_val * x_val
    mean_sq = tl.sum(x_sq, axis=0) / N_COLS
    # 2. Reciprocal standard deviation
    rstd = tl.rsqrt(mean_sq + eps)
    # 3. Normalize and apply the gamma weight
    y_val = x_val * rstd * w_val
    # Write back to memory
    tl.store(out_row_start_ptr + offsets, y_val, mask=mask)


def rms_norm_triton(x, weight, eps):
    # Flatten input to (Total_Rows, Hidden_Size)
    # e.g., (Batch, Seq_Len, Hidden) -> (Batch * Seq_Len, Hidden)
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    # Ensure input is contiguous
    if not x.is_contiguous():
        x = x.contiguous()
    # Flatten to get the correct row stride
    x_flat = x.view(-1, N)
    # Allocate output space
    y = torch.empty_like(x)
    # Block size: round N up to the next power of 2
    BLOCK_SIZE = triton.next_power_of_2(N)
    # Grid dimensions: parallelize over rows
    grid = (M,)
    _rms_norm_kernel[grid](
        x, weight, y,
        x_flat.stride(0),  # Input row stride
        x_flat.stride(0),  # Output row stride (same as input; y is allocated contiguous)
        N, eps,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return y
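

# Reference sketch (added for illustration, not part of the original file):
# a plain PyTorch RMSNorm that the Triton path above should match. The helper
# name is ours; the float32 accumulation mirrors the kernel above.
def _rms_norm_reference(x, weight, eps):
    x_f32 = x.to(torch.float32)
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * rstd * weight.to(torch.float32)).to(x.dtype)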


@triton.jit
def _rotary_kernel(
    X_ptr, Cos_ptr, Sin_ptr,
    stride_x_batch, stride_x_seq, stride_x_head, stride_x_dim,
    stride_c_seq, stride_c_dim,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    """
    Core kernel for Rotary Positional Encoding (RoPE).

    Applies the rotary positional encoding transformation to the input tensor.
    Each dimension i is paired with dimension i + head_dim // 2 (the
    "rotate half" layout) and the pair is rotated, which injects relative
    position information into the sequence.

    Args:
        X_ptr: Pointer to input tensor, shape [batch, seq_len, num_heads, head_dim]
        Cos_ptr: Pointer to cosine table, shape [seq_len, head_dim // 2]
        Sin_ptr: Pointer to sine table, shape [seq_len, head_dim // 2]
        stride_x_batch: Stride of the input's batch dimension
        stride_x_seq: Stride of the input's sequence dimension
        stride_x_head: Stride of the input's head dimension
        stride_x_dim: Stride of the input's last (feature) dimension
        stride_c_seq: Stride of the Cos/Sin table's sequence dimension
        stride_c_dim: Stride of the Cos/Sin table's feature dimension
        HEAD_DIM: Size of the head dimension
        BLOCK_SIZE: Block size for the parallel computation

    Returns:
        None; X is modified in place.
    """
    # Grid structure: (Batch, Seq, Heads)
    batch_id = tl.program_id(0)
    seq_id = tl.program_id(1)
    head_id = tl.program_id(2)
    # Offset of the current head
    x_offset = (
        batch_id * stride_x_batch +
        seq_id * stride_x_seq +
        head_id * stride_x_head
    )
    # Cos/Sin depend only on sequence position (and dim)
    c_offset = seq_id * stride_c_seq
    # RoPE applies to pairs in the head dimension
    HALF_DIM = HEAD_DIM // 2
    # Process the range [0, HALF_DIM) in parallel
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < HALF_DIM
    # Pointers to the two halves of the head vector
    # The last dim is usually contiguous (stride=1), but use the stride for generality
    # Indexing: x[..., i] and x[..., i + HALF_DIM]
    x0_ptr = X_ptr + x_offset + offsets * stride_x_dim
    x1_ptr = X_ptr + x_offset + (offsets + HALF_DIM) * stride_x_dim
    c_ptr = Cos_ptr + c_offset + offsets * stride_c_dim
    s_ptr = Sin_ptr + c_offset + offsets * stride_c_dim  # Sin table has the same layout as Cos
    # Load data
    x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32)
    c = tl.load(c_ptr, mask=mask, other=0.0).to(tl.float32)
    s = tl.load(s_ptr, mask=mask, other=0.0).to(tl.float32)
    # Apply the rotation:
    #   x0_new = x0 * cos - x1 * sin
    #   x1_new = x0 * sin + x1 * cos
    y0 = x0 * c - x1 * s
    y1 = x0 * s + x1 * c
    # Store back to the original location (in place)
    tl.store(x0_ptr, y0, mask=mask)
    tl.store(x1_ptr, y1, mask=mask)
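

# Reference sketch (added for illustration, not part of the original file):
# the same rotate-half pairing written in plain PyTorch for a slice whose last
# dimension is head_dim. cos/sin must broadcast against the leading dims; the
# helper name is ours.
def _rotate_half_reference(x, cos, sin):
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)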


def apply_rotary_pos_emb_triton(x, cos, sin):
    # x: (Batch, Seq, Heads, Dim)
    # cos, sin: (Seq, Dim / 2)
    # The rotation is applied in place on x. Note that if x is not contiguous,
    # it is the contiguous copy made below that gets rotated and returned.
    # Ensure inputs are contiguous
    if not x.is_contiguous():
        x = x.contiguous()
    if not cos.is_contiguous():
        cos = cos.contiguous()
    if not sin.is_contiguous():
        sin = sin.contiguous()
    B, S, H, D = x.shape
    # Block size >= HALF_DIM
    HALF_DIM = D // 2
    BLOCK_SIZE = triton.next_power_of_2(HALF_DIM)
    # Grid: (Batch, Seq, Heads)
    grid = (B, S, H)
    _rotary_kernel[grid](
        x, cos, sin,
        x.stride(0), x.stride(1), x.stride(2), x.stride(3),
        cos.stride(0), cos.stride(1),
        HEAD_DIM=D,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return x
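

# Illustrative helper (added for illustration, not part of the original file):
# one way to build the (seq_len, head_dim // 2) cos/sin tables this wrapper
# expects. The base of 10000.0 is an assumption borrowed from common
# LLaMA-style configs; the helper name is ours.
def _build_rope_tables(seq_len, head_dim, device, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()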


@triton.jit
def _flash_attn_fwd_kernel(
    Q, K, V, sm_scale,
    L,  # Sequence length
    Out,
    stride_q_batch, stride_q_head, stride_q_m, stride_q_k,
    stride_k_batch, stride_k_head, stride_k_n, stride_k_k,
    stride_v_batch, stride_v_head, stride_v_n, stride_v_k,
    stride_o_batch, stride_o_head, stride_o_m, stride_o_n,
    Z, H, N_CTX,  # Batch, Heads, Context Length
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
):
    # Grid: (Tr, B, H)
    start_m = tl.program_id(0)
    off_b = tl.program_id(1)
    off_h = tl.program_id(2)
    # Offsets for the Q, K, V pointers
    # shape: (B, H, S, D)
    q_offset = off_b * stride_q_batch + off_h * stride_q_head
    k_offset = off_b * stride_k_batch + off_h * stride_k_head
    v_offset = off_b * stride_v_batch + off_h * stride_v_head
    o_offset = off_b * stride_o_batch + off_h * stride_o_head
    # Block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_q_m, stride_q_k),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, N_CTX),  # Transposed for Q @ K.T
        strides=(stride_k_k, stride_k_n),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_v_n, stride_v_k),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # Initialize the running max, running sum, and output accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # Load Q
    # We assume N_CTX is a multiple of BLOCK_M / BLOCK_N, or rely on the block
    # pointer's boundary check to pad the ragged edge.
    q = tl.load(Q_block_ptr, boundary_check=(0, 1))
    # Loop over K, V blocks
    # Causal masking: we only attend to keys up to the current query position,
    # so the K loop runs from 0 to (start_m + 1) * BLOCK_M
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    for start_n in range(lo, hi, BLOCK_N):
        # Load K, V
        k = tl.load(K_block_ptr, boundary_check=(0, 1))
        v = tl.load(V_block_ptr, boundary_check=(0, 1))
        # Compute Q @ K.T
        qk = tl.dot(q, k)
        # Apply scaling
        qk *= sm_scale
        # Apply the causal mask
        # Blocks that reach the diagonal need element-wise masking
        if start_n + BLOCK_N > start_m * BLOCK_M:
            offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_n = start_n + tl.arange(0, BLOCK_N)
            mask = offs_m[:, None] >= offs_n[None, :]
            qk = tl.where(mask, qk, float("-inf"))
        # Online softmax update
        m_i_new = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_i_new)
        p = tl.exp(qk - m_i_new[:, None])
        alpha = tl.exp(m_i - m_i_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        # Accumulator update: acc = acc * alpha + p @ v
        acc = acc * alpha[:, None]
        acc += tl.dot(p.to(v.dtype), v)  # Cast p to match v's precision
        # Update statistics
        m_i = m_i_new
        # Advance pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # Final normalization
    # l_i is strictly positive for every valid row, because each query always
    # attends to at least its own position.
    acc = acc / l_i[:, None]
    # Store the output
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_o_m, stride_o_n),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(Out.dtype.element_ty), boundary_check=(0, 1))
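

# Illustrative note (added for illustration, not part of the original file):
# the online-softmax recurrence used in the loop above, written out in plain
# PyTorch for a single query row. The names and block iterables are
# hypothetical; this is a sketch of the idea, not part of the kernel.
def _online_softmax_row(score_blocks, value_blocks):
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = None
    for s, v in zip(score_blocks, value_blocks):  # s: (BLOCK_N,), v: (BLOCK_N, D)
        m_new = torch.maximum(m, s.max())
        p = torch.exp(s - m_new)          # probabilities rescaled by the new max
        alpha = torch.exp(m - m_new)      # correction factor for previous blocks
        l = l * alpha + p.sum()
        acc = p @ v if acc is None else acc * alpha + p @ v
        m = m_new
    return acc / l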


def flash_attention_triton(q, k, v):
    # q, k, v: (Batch, Heads, Seq_Len, Dim)
    B, H, S, D = q.shape
    # Ensure contiguous inputs
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if not v.is_contiguous():
        v = v.contiguous()
    # Block sizes
    BLOCK_M = 128
    BLOCK_N = 64
    # Grid: parallelize over (Seq_Len // BLOCK_M, Batch, Heads)
    grid = (triton.cdiv(S, BLOCK_M), B, H)
    scale = 1.0 / (D ** 0.5)
    # Allocate the output
    o = torch.empty_like(q)
    _flash_attn_fwd_kernel[grid](
        q, k, v, scale,
        S,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        B, H, S,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D
    )
    return o
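

# Reference sketch (added for illustration, not part of the original file):
# a naive causal attention in PyTorch, useful for checking flash_attention_triton
# on small shapes. The helper name and the float32 comparison path are ours.
def _attention_reference(q, k, v):
    B, H, S, D = q.shape
    scale = 1.0 / (D ** 0.5)
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale
    causal = torch.tril(torch.ones(S, S, device=q.device, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).to(q.dtype)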