|
10 | 10 | from .mhsea import mhsea_kernel, reference_mhsea_kernel |
11 | 11 | from .utils import ( |
12 | 12 | big_number, |
| 13 | + compiler_params, |
13 | 14 | compute_q_and_kv_block_len, |
14 | 15 | create_grid, |
15 | 16 | get_lse_block_spec, |
@@ -53,9 +54,7 @@ def mhsa_forward( |
53 | 54 | out_shape=jax.ShapeDtypeStruct( |
54 | 55 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
55 | 56 | ), |
56 | | - compiler_params=dict( |
57 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
58 | | - ), |
| 57 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
59 | 58 | debug=False, |
60 | 59 | interpret=interpret, |
61 | 60 | name='mhsa_forward', |
@@ -113,9 +112,7 @@ def mhsa_backward( |
113 | 112 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
114 | 113 | ), |
115 | 114 | ], |
116 | | - compiler_params=dict( |
117 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
118 | | - ), |
| 115 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
119 | 116 | debug=False, |
120 | 117 | interpret=interpret, |
121 | 118 | name='mhsa_backward', |
@@ -268,9 +265,7 @@ def mhsea_forward( |
268 | 265 | shape=(batch_len, seq_len, num_heads), dtype=v.dtype |
269 | 266 | ), |
270 | 267 | ], |
271 | | - compiler_params=dict( |
272 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
273 | | - ), |
| 268 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
274 | 269 | debug=False, |
275 | 270 | interpret=interpret, |
276 | 271 | name='mhea_forward', |
@@ -372,9 +367,7 @@ def mhsea_backward( |
372 | 367 | shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype |
373 | 368 | ), |
374 | 369 | ], |
375 | | - compiler_params=dict( |
376 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
377 | | - ), |
| 370 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
378 | 371 | debug=False, |
379 | 372 | interpret=interpret, |
380 | 373 | name='mhsea_backward_q_vjp', |
@@ -433,9 +426,7 @@ def mhsea_backward( |
433 | 426 | shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype |
434 | 427 | ), |
435 | 428 | ], |
436 | | - compiler_params=dict( |
437 | | - triton=dict(num_warps=num_warps, num_stages=num_stages) |
438 | | - ), |
| 429 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
439 | 430 | debug=False, |
440 | 431 | interpret=interpret, |
441 | 432 | name='mhsea_backward_kv_vjp', |
|
0 commit comments