Skip to content

Commit 2b2e45b

Browse files
committed
Use TritonCompilerParams
1 parent 730abba commit 2b2e45b

File tree

4 files changed

+22
-18
lines changed

4 files changed

+22
-18
lines changed

folx/experimental/pallas/attention/custom_gradients.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8+
from jax.experimental.pallas import gpu as plgpu
89

910
from .mhsa import mhsa_kernel, reference_mhsa_kernel
1011
from .mhsea import mhsea_kernel, reference_mhsea_kernel
@@ -53,8 +54,8 @@ def mhsa_forward(
5354
out_shape=jax.ShapeDtypeStruct(
5455
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
5556
),
56-
compiler_params=dict(
57-
triton=dict(num_warps=num_warps, num_stages=num_stages)
57+
compiler_params=plgpu.TritonCompilerParams(
58+
num_warps=num_warps, num_stages=num_stages
5859
),
5960
debug=False,
6061
interpret=interpret,
@@ -113,8 +114,8 @@ def mhsa_backward(
113114
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
114115
),
115116
],
116-
compiler_params=dict(
117-
triton=dict(num_warps=num_warps, num_stages=num_stages)
117+
compiler_params=plgpu.TritonCompilerParams(
118+
num_warps=num_warps, num_stages=num_stages
118119
),
119120
debug=False,
120121
interpret=interpret,
@@ -268,8 +269,8 @@ def mhsea_forward(
268269
shape=(batch_len, seq_len, num_heads), dtype=v.dtype
269270
),
270271
],
271-
compiler_params=dict(
272-
triton=dict(num_warps=num_warps, num_stages=num_stages)
272+
compiler_params=plgpu.TritonCompilerParams(
273+
num_warps=num_warps, num_stages=num_stages
273274
),
274275
debug=False,
275276
interpret=interpret,
@@ -372,8 +373,8 @@ def mhsea_backward(
372373
shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
373374
),
374375
],
375-
compiler_params=dict(
376-
triton=dict(num_warps=num_warps, num_stages=num_stages)
376+
compiler_params=plgpu.TritonCompilerParams(
377+
num_warps=num_warps, num_stages=num_stages
377378
),
378379
debug=False,
379380
interpret=interpret,
@@ -433,8 +434,8 @@ def mhsea_backward(
433434
shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
434435
),
435436
],
436-
compiler_params=dict(
437-
triton=dict(num_warps=num_warps, num_stages=num_stages)
437+
compiler_params=plgpu.TritonCompilerParams(
438+
num_warps=num_warps, num_stages=num_stages
438439
),
439440
debug=False,
440441
interpret=interpret,

folx/experimental/pallas/attention/forward_laplacian.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8+
from jax.experimental.pallas import gpu as plgpu
89

910
from folx import forward_laplacian
1011
from folx.api import FwdJacobian, FwdLaplArray
@@ -153,8 +154,8 @@ def mhsa_forward_laplacian(
153154
dtype=q.dtype, # o.laplacian
154155
),
155156
],
156-
compiler_params=dict(
157-
triton=dict(num_warps=num_warps, num_stages=num_stages)
157+
compiler_params=plgpu.TritonCompilerParams(
158+
num_warps=num_warps, num_stages=num_stages
158159
),
159160
debug=False,
160161
interpret=interpret,
@@ -588,8 +589,8 @@ def mhsea_forward_laplacian(
588589
dtype=v.dtype, # o.laplacian
589590
),
590591
],
591-
compiler_params=dict(
592-
triton=dict(num_warps=num_warps, num_stages=num_stages)
592+
compiler_params=plgpu.TritonCompilerParams(
593+
num_warps=num_warps, num_stages=num_stages
593594
),
594595
debug=False,
595596
interpret=interpret,

folx/experimental/pallas/attention/mhsa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8+
from jax.experimental.pallas import gpu as plgpu
89

910
from .utils import (
1011
big_number,
@@ -58,8 +59,8 @@ def mhsa(
5859
out_shape=jax.ShapeDtypeStruct(
5960
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
6061
),
61-
compiler_params=dict(
62-
triton=dict(num_warps=num_warps, num_stages=num_stages)
62+
compiler_params=plgpu.TritonCompilerParams(
63+
num_warps=num_warps, num_stages=num_stages
6364
),
6465
debug=False,
6566
interpret=interpret,

folx/experimental/pallas/attention/mhsea.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8+
from jax.experimental.pallas import gpu as plgpu
89

910
from .utils import (
1011
big_number,
@@ -58,8 +59,8 @@ def mhsea(
5859
shape=(batch_len, seq_len, num_heads), dtype=q.dtype
5960
), # lse
6061
],
61-
compiler_params=dict(
62-
triton=dict(num_warps=num_warps, num_stages=num_stages)
62+
compiler_params=plgpu.TritonCompilerParams(
63+
num_warps=num_warps, num_stages=num_stages
6364
),
6465
debug=False,
6566
interpret=interpret,

0 commit comments

Comments (0)