-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path06-fused-attention.py
More file actions
378 lines (327 loc) · 12.9 KB
/
06-fused-attention.py
File metadata and controls
378 lines (327 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Copied from https://github.com/NVIDIA/TileGym/blob/main/src/tilegym/ops/cutile/attention.py
import torch
import cuda.tile as ct
from cuda.tile import RoundingMode as RMd
import math
# 1 / ln(2) == log2(e). The kernel multiplies qk_scale by this factor so the
# softmax exponentials can be evaluated with exp2 instead of exp.
INV_LOG_2 = 1.0 / math.log(2)
# Type aliases for kernel parameters that are Constant integers and booleans
# (used below to annotate the tile sizes and the CAUSAL / EVEN_K flags).
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
# --- FMHA Kernel Implementation ---
@ct.kernel()
def fmha_kernel(Q, K, V, Out,
                qk_scale: float,
                input_pos: int,
                TILE_D: ConstInt,  # TILE_D = hidden_size
                H: ConstInt,
                TILE_M: ConstInt,
                TILE_N: ConstInt,
                QUERY_GROUP_SIZE: ConstInt,
                CAUSAL: ConstBool,
                EVEN_K: ConstBool):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).

    Computes attention output for a specific batch item and head, using
    tiling and online softmax. Each block handles one TILE_M-row chunk of
    queries (block index 0) for one (batch, head) pair (block index 1).

    Args:
        Q: Query tensor, tiled as (batch, head, query position, feature).
        K: Key tensor, tiled as (batch, kv head, key position, feature);
            ``K.shape[2]`` is used as the key sequence length.
        V: Value tensor, tiled like K.
        Out: Output tensor, written with the same tiling as Q.
        qk_scale: Softmax scale applied to the QK product. It is rescaled
            by ``INV_LOG_2`` internally because the kernel uses ``exp2``.
        input_pos: Offset added to the query row positions that are used
            for causal masking and for bounding the K/V loop.
        TILE_D: Feature tile size (the launcher passes the hidden size).
        H: Number of query heads; block index 1 is decoded as
            ``batch_idx * H + head_idx``.
        TILE_M: Number of query rows processed per block.
        TILE_N: Number of key/value positions processed per loop step.
        QUERY_GROUP_SIZE: Number of query heads that share one K/V head.
        CAUSAL: When True, a query at position m only sees keys n <= m.
        EVEN_K: When False, key positions >= the key sequence length are
            masked out.
    """
    # Map block IDs to batch and head indices
    bid_x = ct.bid(0)
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H
    off_kv_h = head_idx // QUERY_GROUP_SIZE

    # Adjust qk_scale for exp2
    qk_scale = qk_scale * INV_LOG_2

    # Initialize offsets for current query tile (M-dimension)
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)  # [TILE_M]
    offs_m += input_pos
    offs_m = offs_m[:, None]  # [TILE_M, 1]

    # Initialize local offsets for key/value tile (N-dimension)
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)  # [TILE_N]
    offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]

    # Initialize online softmax accumulators in float32 for stability
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)  # running row max
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)  # running row sum
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)  # unnormalized output

    # Load query tile for this batch, head, and M-chunk
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]

    # Loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    if CAUSAL:
        # When kv pos could exceed q pos
        mask_start = (input_pos + bid_x * TILE_M) // TILE_N
        # When kv pos could exceed k_seqlen
        mask_start = min(mask_start, k_seqlen // TILE_N)
        # Keys past the last query row of this tile are never visible,
        # so the loop stops at min(m_end, k_seqlen).
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        # Only a trailing partial block (if any) needs bounds masking.
        mask_start = k_seqlen // TILE_N

    # Loop over K, V blocks (N-dimension chunks)
    for j in range(0, Tc):
        # --- Compute QK product ---
        # K is loaded with its last two dims permuted (order=(0, 1, 3, 2)),
        # so the tile arrives as [TILE_D, TILE_N] and no transpose is needed.
        k = ct.load(
            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

        # --- Apply Causal Masking ---
        # Blocks before mask_start are fully visible and skip this step.
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
            # Out of bound mask
            if not EVEN_K:
                mask = mask & (offs_n < k_seqlen)
            # Causal mask
            if CAUSAL:
                mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
            # Additive mask: 0 keeps a score, -inf removes it from the softmax
            mask = ct.where(mask, 0.0, -math.inf)  # [TILE_M, TILE_N]
            qk += mask

        # --- Online Softmax Update ---
        # Moving qk_scale multiplication after reduce_max is to improve performance.
        m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
        qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]

        # Attention weights
        p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
        l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
        # Correction factor that rebases earlier partial results onto the new max
        alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
        # Update m_i and l_i
        l_i = l_i * alpha + l_ij  # [TILE_M, 1]
        # Scale acc
        acc = acc * alpha  # [TILE_M, TILE_D]

        # --- Compute PV product ---
        v = ct.load(
            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
            latency=4,
        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
        # Cast the weights to the input dtype before the second matmul
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_D]
        m_i = m_ij  # [TILE_M, 1]

    # --- Final Normalization and Store ---
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
import sys
import os
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from autotuner import Autotuner, Config, autotune
def _fmha_autotune_configs():
    """
    Get autotune configurations for FMHA kernel.

    The candidate list depends on the compute capability reported by
    ``torch.cuda.get_device_capability()`` for the current device.

    Returns:
        A list of ``Config`` objects, each carrying TILE_M, TILE_N,
        num_ctas and occupancy settings.
    """
    gpu_capability = torch.cuda.get_device_capability()
    if gpu_capability in ((12, 0), (12, 1)):
        # sm120, sm121
        return [
            Config(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2),
        ]
    # Fallback for every other capability; tuned for sm100 (Blackwell)
    return [
        Config(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1),
        Config(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2),
    ]
@autotune(search_space=_fmha_autotune_configs())
def cutile_autotune_fmha(
    q,
    k,
    v,
    o,
    sm_scale,
    input_pos,
    hidden_size,
    num_heads,
    query_group_size,
    is_causal,
    EVEN_K,
    autotuner: Autotuner | None = None,
):
    """
    Launch ``fmha_kernel`` through the autotuner and return the output tensor.

    Args:
        q, k, v: Input tensors; ``q.shape`` is unpacked as
            (batch, heads, q_len, feature).
        o: Pre-allocated output tensor that the kernel writes into.
        sm_scale: Softmax scale forwarded to the kernel as ``qk_scale``.
        input_pos: Query position offset forwarded to the kernel.
        hidden_size: Forwarded to the kernel as ``TILE_D``.
        num_heads: Number of query heads (kernel parameter ``H``).
        query_group_size: Number of query heads per K/V head.
        is_causal: Forwarded to the kernel as ``CAUSAL``.
        EVEN_K: Forwarded to the kernel as ``EVEN_K``.
        autotuner: Callable that launches the kernel for a chosen config.
            Presumably injected by the ``@autotune`` decorator -- verify
            against the autotuner module.

    Returns:
        ``o``, filled in by the kernel.
    """
    batch_size, _, q_len, _ = q.shape

    def grid_fn(named_args, cfg):
        # One block per TILE_M-row query chunk (integer ceil-division)
        # times one block per (batch, head) pair.
        return (
            -(-q_len // cfg.TILE_M),
            batch_size * num_heads,
            1,
        )

    def args_fn(cfg):
        # Positional order must match the fmha_kernel signature.
        return (
            q,
            k,
            v,
            o,
            sm_scale,
            input_pos,
            hidden_size,
            num_heads,
            cfg.TILE_M,
            cfg.TILE_N,
            query_group_size,
            is_causal,
            EVEN_K,
        )

    # The tuner's return value is not needed; the result lands in ``o``.
    autotuner(
        torch.cuda.current_stream(),
        grid_fn=grid_fn,
        kernel=fmha_kernel,
        args_fn=args_fn,
    )
    return o
def tile_prefill_fmha(q, k, v, sm_scale, is_causal=True, kernel_configs=None):
    """
    Run prefill attention with the autotuned cuTile FMHA kernel.

    Args:
        q: Query tensor; its shape is unpacked as
            (batch, num_heads, q_len, hidden_size).
        k: Key tensor; its shape is unpacked as
            (batch, num_kv_heads, k_len, hidden_size).
        v: Value tensor.
        sm_scale: Softmax scale; ``None`` selects ``1 / sqrt(q.size(-1))``.
        is_causal: Whether to apply causal masking.
        kernel_configs: Accepted for interface compatibility; currently
            not used by this function.

    Returns:
        A new tensor shaped like ``q`` holding the attention output.
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.size(-1))

    _, num_heads, _, hidden_size = q.shape
    _, num_head_kv, k_len, _ = k.shape

    assert num_heads % num_head_kv == 0, (
        "number of query heads must be a multiple of the number of KV heads"
    )
    query_group_size = num_heads // num_head_kv

    # ``contiguous()`` is a no-op for tensors that are already contiguous.
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    o = torch.empty_like(q)

    input_pos = 0  # prefill, causal

    # The autotuner may pick any candidate TILE_N, so the kernel can only
    # skip out-of-bounds masking when k_len divides evenly for all of them.
    tile_ns = {cfg.kwargs['TILE_N'] for cfg in _fmha_autotune_configs()}
    EVEN_K = all(k_len % tile_n == 0 for tile_n in tile_ns)

    return cutile_autotune_fmha(
        q, k, v, o, sm_scale, input_pos, hidden_size, num_heads, query_group_size, is_causal, EVEN_K
    )
def cutile_fmha(
    q,
    k,
    v,
    scaling=None,
    is_causal=True,
    **kwargs,
):
    """
    Public entry point for the cuTile fused attention forward pass.

    Args:
        q, k, v: Query, key and value tensors.
        scaling: Softmax scale; ``None`` selects ``1 / sqrt(q.size(-1))``.
        is_causal: Whether to apply causal masking.
        **kwargs: Only ``kernel_configs`` is read; it is forwarded to
            ``tile_prefill_fmha``. Other keys are ignored.

    Returns:
        The attention output tensor produced by ``tile_prefill_fmha``.
    """
    if scaling is None:
        scaling = 1.0 / math.sqrt(q.size(-1))
    kernel_configs = kwargs.get('kernel_configs', None)
    return tile_prefill_fmha(q, k, v, scaling, is_causal, kernel_configs)
def reference_fmha(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scaling: float = None,
is_causal: bool = True,
):
"""Reference implementation using PyTorch"""
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=scaling
)
# --- Correctness check (runs at import time) ---
# Compare the cuTile kernel against the PyTorch reference on random fp16
# inputs, using the default causal setting and default scaling of both.
DEVICE = torch.cuda.current_device()
BATCH, H, N_CTX, HEAD_DIM = 4, 48, 1024, 128
dtype = torch.float16
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)
cutile_output = cutile_fmha(q, k, v)
torch_output = reference_fmha(q, k, v)
# Absolute tolerance only (rtol=0).
if torch.allclose(cutile_output, torch_output, atol=1e-2, rtol=0):
    print("✅ cuTile and Torch match")
else:
    print("❌ cuTile and Torch differ")
import triton
# flash-attn is an optional benchmark baseline. Any ordinary failure while
# importing it (package missing, broken native extension, ...) just disables
# the "flash" provider. Catching Exception rather than BaseException keeps
# KeyboardInterrupt and SystemExit propagating.
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
except Exception:
    HAS_FLASH = False
from triton_kernels import is_blackwell, is_rtx_blackwell, is_hopper
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
if is_rtx_blackwell():
    # RTX 5090 only has 32G Memory
    BATCH, N_HEADS = 1, 4
else:
    BATCH, N_HEADS = 4, 32

# Providers plotted in every benchmark; the fp8 and flash entries are only
# present when the running environment supports them. The two lists must
# stay in the same order.
_LINE_VALS = (
    ["cutile-fp16"] + ["triton-fp16"] +
    (["cutile-fp8"] if TORCH_HAS_FP8 else []) +
    (["triton-fp8"] if TORCH_HAS_FP8 else []) +
    (["flash"] if HAS_FLASH else [])
)
_LINE_NAMES = (
    ["cuTile [FP16]"] + ["Triton [fp16]"] +
    (["cuTile [FP8]"] if TORCH_HAS_FP8 else []) +
    (["Triton [fp8]"] if TORCH_HAS_FP8 else []) +
    (["Flash-2"] if HAS_FLASH else [])
)

# Vary sequence length for a fixed head count and batch size
configs = []
for HEAD_DIM in [64, 128]:
    # Only the forward pass is benchmarked; add "bwd" here to include backward.
    for mode in ["fwd"]:
        for causal in [True, False]:
            # Warp specialization is tried for fwd on Blackwell, and on
            # Hopper only for the non-causal case.
            enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal))
            for warp_specialize in [False, True] if enable_ws else [False]:
                configs.append(
                    triton.testing.Benchmark(
                        x_names=["N_CTX"],
                        x_vals=[2**i for i in range(10, 15)],
                        line_arg="provider",
                        line_vals=list(_LINE_VALS),
                        line_names=list(_LINE_NAMES),
                        styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("orange", "-"), ("pink", "-")],
                        ylabel="TFLOPS",
                        plot_name=
                        f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
                        args={
                            "H": N_HEADS,
                            "BATCH": BATCH,
                            "HEAD_DIM": HEAD_DIM,
                            "mode": mode,
                            "causal": causal,
                            "warp_specialize": warp_specialize,
                        },
                    ))
from triton_kernels import attention as triton_attention
def _make_bench_qkv(BATCH, H, N_CTX, HEAD_DIM, dtype, device, use_fp8):
    """Create random q, k, v benchmark inputs, optionally cast to float8_e5m2."""
    shape = (BATCH, H, N_CTX, HEAD_DIM)
    q = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    if use_fp8:
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)
        # Round trip through a transposed contiguous copy: v keeps its
        # logical shape but is no longer laid out contiguously.
        v = v.permute(0, 1, 3, 2).contiguous()
        v = v.permute(0, 1, 3, 2)
        v = v.to(torch.float8_e5m2)
    return q, k, v


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE):
    """
    Benchmark one attention provider and return its throughput in TFLOPS.

    Args:
        BATCH, H, N_CTX, HEAD_DIM: Batch size, head count, sequence length
            and head dimension of the random inputs.
        causal: Whether causal masking is used.
        warp_specialize: Forwarded to the Triton implementation only.
        mode: "fwd" or "bwd".
        provider: Name containing "triton" or "cutile" (an "fp8" substring
            selects float8 inputs in fwd mode), or exactly "flash".
        device: Device on which the inputs are allocated.

    Returns:
        Achieved TFLOPS computed from the time reported by
        ``triton.testing.do_bench``.

    Raises:
        ValueError: If ``provider`` is not recognized.
    """
    assert mode in ["fwd", "bwd"]
    dtype = torch.float16
    sm_scale = 1.3
    use_fp8 = mode == "fwd" and "fp8" in provider

    if "triton" in provider:
        q, k, v = _make_bench_qkv(BATCH, H, N_CTX, HEAD_DIM, dtype, device, use_fp8)
        fn = lambda: triton_attention(q, k, v, causal, sm_scale, warp_specialize)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn)
    elif "cutile" in provider:
        # NOTE(review): only the forward pass is timed for cuTile, even when
        # mode == "bwd"; the benchmark configs currently use "fwd" only.
        q, k, v = _make_bench_qkv(BATCH, H, N_CTX, HEAD_DIM, dtype, device, use_fp8)
        fn = lambda: cutile_fmha(q, k, v, scaling=sm_scale, is_causal=causal)
        ms = triton.testing.do_bench(fn)
    elif provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn)
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # Two matmuls (QK^T and PV), each 2 * BATCH * H * N_CTX^2 * HEAD_DIM FLOPs
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
    total_flops = 2 * flops_per_matmul
    if causal:
        # Roughly half of the score matrix is masked out
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    # ms is in milliseconds: convert to FLOPs per second, then to TFLOPS
    return total_flops * 1e-12 / (ms * 1e-3)
if __name__ == "__main__":
    # only works on Blackwell GPUs right now
    bench_flash_attention.run(show_plots=True, print_data=True)