|
6 | 6 | import jax.numpy as jnp |
7 | 7 | from jax.experimental import pallas as pl |
8 | 8 |
|
9 | | -try: |
10 | | - from jax.experimental.pallas import triton as plgpu |
11 | | -except ImportError: |
12 | | - from jax.experimental.pallas import gpu as plgpu |
13 | | - |
14 | 9 | from .mhsa import mhsa_kernel, reference_mhsa_kernel |
15 | 10 | from .mhsea import mhsea_kernel, reference_mhsea_kernel |
16 | 11 | from .utils import ( |
17 | 12 | big_number, |
| 13 | + compiler_params, |
18 | 14 | compute_q_and_kv_block_len, |
19 | 15 | create_grid, |
20 | 16 | get_lse_block_spec, |
@@ -58,9 +54,7 @@ def mhsa_forward( |
58 | 54 | out_shape=jax.ShapeDtypeStruct( |
59 | 55 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
60 | 56 | ), |
61 | | - compiler_params=plgpu.TritonCompilerParams( |
62 | | - num_warps=num_warps, num_stages=num_stages |
63 | | - ), |
| 57 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
64 | 58 | debug=False, |
65 | 59 | interpret=interpret, |
66 | 60 | name='mhsa_forward', |
@@ -118,9 +112,7 @@ def mhsa_backward( |
118 | 112 | shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype |
119 | 113 | ), |
120 | 114 | ], |
121 | | - compiler_params=plgpu.TritonCompilerParams( |
122 | | - num_warps=num_warps, num_stages=num_stages |
123 | | - ), |
| 115 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
124 | 116 | debug=False, |
125 | 117 | interpret=interpret, |
126 | 118 | name='mhsa_backward', |
@@ -273,9 +265,7 @@ def mhsea_forward( |
273 | 265 | shape=(batch_len, seq_len, num_heads), dtype=v.dtype |
274 | 266 | ), |
275 | 267 | ], |
276 | | - compiler_params=plgpu.TritonCompilerParams( |
277 | | - num_warps=num_warps, num_stages=num_stages |
278 | | - ), |
| 268 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
279 | 269 | debug=False, |
280 | 270 | interpret=interpret, |
281 | 271 | name='mhea_forward', |
@@ -377,9 +367,7 @@ def mhsea_backward( |
377 | 367 | shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype |
378 | 368 | ), |
379 | 369 | ], |
380 | | - compiler_params=plgpu.TritonCompilerParams( |
381 | | - num_warps=num_warps, num_stages=num_stages |
382 | | - ), |
| 370 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
383 | 371 | debug=False, |
384 | 372 | interpret=interpret, |
385 | 373 | name='mhsea_backward_q_vjp', |
@@ -438,9 +426,7 @@ def mhsea_backward( |
438 | 426 | shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype |
439 | 427 | ), |
440 | 428 | ], |
441 | | - compiler_params=plgpu.TritonCompilerParams( |
442 | | - num_warps=num_warps, num_stages=num_stages |
443 | | - ), |
| 429 | + compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages), |
444 | 430 | debug=False, |
445 | 431 | interpret=interpret, |
446 | 432 | name='mhsea_backward_kv_vjp', |
|
0 commit comments