Skip to content

Commit 2634e24

Browse files
committed
Fix for lower versions
1 parent a4a1401 commit 2634e24

File tree

5 files changed

+28
-47
lines changed

5 files changed

+28
-47
lines changed

folx/experimental/pallas/attention/custom_gradients.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,11 @@
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
88

9-
try:
10-
from jax.experimental.pallas import triton as plgpu
11-
except ImportError:
12-
from jax.experimental.pallas import gpu as plgpu
13-
149
from .mhsa import mhsa_kernel, reference_mhsa_kernel
1510
from .mhsea import mhsea_kernel, reference_mhsea_kernel
1611
from .utils import (
1712
big_number,
13+
compiler_params,
1814
compute_q_and_kv_block_len,
1915
create_grid,
2016
get_lse_block_spec,
@@ -58,9 +54,7 @@ def mhsa_forward(
5854
out_shape=jax.ShapeDtypeStruct(
5955
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
6056
),
61-
compiler_params=plgpu.TritonCompilerParams(
62-
num_warps=num_warps, num_stages=num_stages
63-
),
57+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
6458
debug=False,
6559
interpret=interpret,
6660
name='mhsa_forward',
@@ -118,9 +112,7 @@ def mhsa_backward(
118112
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
119113
),
120114
],
121-
compiler_params=plgpu.TritonCompilerParams(
122-
num_warps=num_warps, num_stages=num_stages
123-
),
115+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
124116
debug=False,
125117
interpret=interpret,
126118
name='mhsa_backward',
@@ -273,9 +265,7 @@ def mhsea_forward(
273265
shape=(batch_len, seq_len, num_heads), dtype=v.dtype
274266
),
275267
],
276-
compiler_params=plgpu.TritonCompilerParams(
277-
num_warps=num_warps, num_stages=num_stages
278-
),
268+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
279269
debug=False,
280270
interpret=interpret,
281271
name='mhea_forward',
@@ -377,9 +367,7 @@ def mhsea_backward(
377367
shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
378368
),
379369
],
380-
compiler_params=plgpu.TritonCompilerParams(
381-
num_warps=num_warps, num_stages=num_stages
382-
),
370+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
383371
debug=False,
384372
interpret=interpret,
385373
name='mhsea_backward_q_vjp',
@@ -438,9 +426,7 @@ def mhsea_backward(
438426
shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
439427
),
440428
],
441-
compiler_params=plgpu.TritonCompilerParams(
442-
num_warps=num_warps, num_stages=num_stages
443-
),
429+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
444430
debug=False,
445431
interpret=interpret,
446432
name='mhsea_backward_kv_vjp',

folx/experimental/pallas/attention/forward_laplacian.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,14 @@
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
88

9-
try:
10-
from jax.experimental.pallas import triton as plgpu
11-
except ImportError:
12-
from jax.experimental.pallas import gpu as plgpu
13-
149
from folx import forward_laplacian
1510
from folx.api import FwdJacobian, FwdLaplArray
1611

1712
from .mhsa import reference_mhsa_kernel
1813
from .mhsea import reference_mhsea_kernel
1914
from .utils import (
2015
big_number,
16+
compiler_params,
2117
compute_q_and_kv_block_len,
2218
create_grid,
2319
get_input_mask_block_spec,
@@ -158,9 +154,7 @@ def mhsa_forward_laplacian(
158154
dtype=q.dtype, # o.laplacian
159155
),
160156
],
161-
compiler_params=plgpu.TritonCompilerParams(
162-
num_warps=num_warps, num_stages=num_stages
163-
),
157+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
164158
debug=False,
165159
interpret=interpret,
166160
name='mhsa_forward_laplacian',
@@ -593,9 +587,7 @@ def mhsea_forward_laplacian(
593587
dtype=v.dtype, # o.laplacian
594588
),
595589
],
596-
compiler_params=plgpu.TritonCompilerParams(
597-
num_warps=num_warps, num_stages=num_stages
598-
),
590+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
599591
debug=False,
600592
interpret=interpret,
601593
name='mhsea_forward_laplacian',

folx/experimental/pallas/attention/mhsa.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,9 @@
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
88

9-
try:
10-
from jax.experimental.pallas import triton as plgpu
11-
except ImportError:
12-
from jax.experimental.pallas import gpu as plgpu
13-
149
from .utils import (
1510
big_number,
11+
compiler_params,
1612
compute_q_and_kv_block_len,
1713
create_grid,
1814
get_mask_block_spec,
@@ -63,9 +59,7 @@ def mhsa(
6359
out_shape=jax.ShapeDtypeStruct(
6460
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
6561
),
66-
compiler_params=plgpu.TritonCompilerParams(
67-
num_warps=num_warps, num_stages=num_stages
68-
),
62+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
6963
debug=False,
7064
interpret=interpret,
7165
name='mhsa',

folx/experimental/pallas/attention/mhsea.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,9 @@
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
88

9-
try:
10-
from jax.experimental.pallas import triton as plgpu
11-
except ImportError:
12-
from jax.experimental.pallas import gpu as plgpu
13-
149
from .utils import (
1510
big_number,
11+
compiler_params,
1612
compute_q_and_kv_block_len,
1713
create_grid,
1814
get_lse_block_spec,
@@ -63,9 +59,7 @@ def mhsea(
6359
shape=(batch_len, seq_len, num_heads), dtype=q.dtype
6460
), # lse
6561
],
66-
compiler_params=plgpu.TritonCompilerParams(
67-
num_warps=num_warps, num_stages=num_stages
68-
),
62+
compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
6963
debug=False,
7064
interpret=interpret,
7165
name='mhsea',

folx/experimental/pallas/attention/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
import jax.numpy as jnp
55
from jax.experimental import pallas as pl
66

7+
try:
8+
from jax.experimental.pallas import triton as plgpu
9+
except ImportError:
10+
from jax.experimental.pallas import gpu as plgpu
11+
12+
13+
from packaging.version import Version
14+
715

816
def sum_columns(x: jax.Array) -> jax.Array:
917
return x.sum(axis=1, keepdims=True)
@@ -210,3 +218,10 @@ def big_number(dtype) -> float:
210218
return 1e40
211219
else:
212220
raise ValueError(f'Unexpected dtype {dtype}')
221+
222+
223+
def compiler_params(num_warps, num_stages):
    """Return ``compiler_params`` for ``pl.pallas_call`` matching the installed JAX.

    JAX 0.4.34 and newer expect a ``plgpu.TritonCompilerParams`` dataclass,
    while earlier releases take a nested dict keyed by backend name.

    Args:
        num_warps: Number of warps to launch per Triton program instance.
        num_stages: Number of pipeline stages for the Triton compiler.

    Returns:
        Either a ``TritonCompilerParams`` instance (JAX >= 0.4.34) or a
        ``{'triton': {...}}`` dict (older JAX), carrying the same settings.
    """
    # NOTE(review): version gate — the dataclass API only exists on newer JAX.
    wants_dataclass = Version(jax.__version__) >= Version('0.4.34')
    if wants_dataclass:
        return plgpu.TritonCompilerParams(
            num_warps=num_warps, num_stages=num_stages
        )
    return {'triton': {'num_warps': num_warps, 'num_stages': num_stages}}

0 commit comments

Comments (0)