Skip to content

Commit f9b237a

Browse files
authored
Merge pull request #34 from microsoft/ae-foster/fix-jax-version-pallas
Use `TritonCompilerParams` in `pallas_call`
2 parents 730abba + 2634e24 commit f9b237a

File tree

5 files changed

+28
-27
lines changed

5 files changed

+28
-27
lines changed

folx/experimental/pallas/attention/custom_gradients.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .mhsea import mhsea_kernel, reference_mhsea_kernel
1111
from .utils import (
1212
big_number,
13+
compiler_params,
1314
compute_q_and_kv_block_len,
1415
create_grid,
1516
get_lse_block_spec,
@@ -53,9 +54,7 @@ def mhsa_forward(
5354
out_shape=jax.ShapeDtypeStruct(
5455
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
5556
),
56-
compiler_params=dict(
57-
triton=dict(num_warps=num_warps, num_stages=num_stages)
58-
),
57+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
5958
debug=False,
6059
interpret=interpret,
6160
name='mhsa_forward',
@@ -113,9 +112,7 @@ def mhsa_backward(
113112
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
114113
),
115114
],
116-
compiler_params=dict(
117-
triton=dict(num_warps=num_warps, num_stages=num_stages)
118-
),
115+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
119116
debug=False,
120117
interpret=interpret,
121118
name='mhsa_backward',
@@ -268,9 +265,7 @@ def mhsea_forward(
268265
shape=(batch_len, seq_len, num_heads), dtype=v.dtype
269266
),
270267
],
271-
compiler_params=dict(
272-
triton=dict(num_warps=num_warps, num_stages=num_stages)
273-
),
268+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
274269
debug=False,
275270
interpret=interpret,
276271
name='mhea_forward',
@@ -372,9 +367,7 @@ def mhsea_backward(
372367
shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
373368
),
374369
],
375-
compiler_params=dict(
376-
triton=dict(num_warps=num_warps, num_stages=num_stages)
377-
),
370+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
378371
debug=False,
379372
interpret=interpret,
380373
name='mhsea_backward_q_vjp',
@@ -433,9 +426,7 @@ def mhsea_backward(
433426
shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
434427
),
435428
],
436-
compiler_params=dict(
437-
triton=dict(num_warps=num_warps, num_stages=num_stages)
438-
),
429+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
439430
debug=False,
440431
interpret=interpret,
441432
name='mhsea_backward_kv_vjp',

folx/experimental/pallas/attention/forward_laplacian.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .mhsea import reference_mhsea_kernel
1414
from .utils import (
1515
big_number,
16+
compiler_params,
1617
compute_q_and_kv_block_len,
1718
create_grid,
1819
get_input_mask_block_spec,
@@ -153,9 +154,7 @@ def mhsa_forward_laplacian(
153154
dtype=q.dtype, # o.laplacian
154155
),
155156
],
156-
compiler_params=dict(
157-
triton=dict(num_warps=num_warps, num_stages=num_stages)
158-
),
157+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
159158
debug=False,
160159
interpret=interpret,
161160
name='mhsa_forward_laplacian',
@@ -588,9 +587,7 @@ def mhsea_forward_laplacian(
588587
dtype=v.dtype, # o.laplacian
589588
),
590589
],
591-
compiler_params=dict(
592-
triton=dict(num_warps=num_warps, num_stages=num_stages)
593-
),
590+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
594591
debug=False,
595592
interpret=interpret,
596593
name='mhsea_forward_laplacian',

folx/experimental/pallas/attention/mhsa.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from .utils import (
1010
big_number,
11+
compiler_params,
1112
compute_q_and_kv_block_len,
1213
create_grid,
1314
get_mask_block_spec,
@@ -58,9 +59,7 @@ def mhsa(
5859
out_shape=jax.ShapeDtypeStruct(
5960
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
6061
),
61-
compiler_params=dict(
62-
triton=dict(num_warps=num_warps, num_stages=num_stages)
63-
),
62+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
6463
debug=False,
6564
interpret=interpret,
6665
name='mhsa',

folx/experimental/pallas/attention/mhsea.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from .utils import (
1010
big_number,
11+
compiler_params,
1112
compute_q_and_kv_block_len,
1213
create_grid,
1314
get_lse_block_spec,
@@ -58,9 +59,7 @@ def mhsea(
5859
shape=(batch_len, seq_len, num_heads), dtype=q.dtype
5960
), # lse
6061
],
61-
compiler_params=dict(
62-
triton=dict(num_warps=num_warps, num_stages=num_stages)
63-
),
62+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
6463
debug=False,
6564
interpret=interpret,
6665
name='mhsea',

folx/experimental/pallas/attention/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
import jax.numpy as jnp
55
from jax.experimental import pallas as pl
66

7+
try:
8+
from jax.experimental.pallas import triton as plgpu
9+
except ImportError:
10+
from jax.experimental.pallas import gpu as plgpu
11+
12+
13+
from packaging.version import Version
14+
715

816
def sum_columns(x: jax.Array) -> jax.Array:
917
return x.sum(axis=1, keepdims=True)
@@ -210,3 +218,10 @@ def big_number(dtype) -> float:
210218
return 1e40
211219
else:
212220
raise ValueError(f'Unexpected dtype {dtype}')
221+
222+
223+
def compiler_params(num_warps, num_stages):
224+
if Version(jax.__version__) >= Version('0.4.34'):
225+
return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
226+
else:
227+
return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))

0 commit comments

Comments (0)