Skip to content

Commit ea364bc

Browse files
committed
fix compiler params for newer jax versions
1 parent 30b053a commit ea364bc

File tree

1 file changed

+6
-1
lines changed
  • folx/experimental/pallas/attention

1 file changed

+6
-1
lines changed

folx/experimental/pallas/attention/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ def big_number(dtype) -> float:
222222

223223
def compiler_params(num_warps, num_stages):
224224
if Version(jax.__version__) >= Version('0.4.34'):
225-
return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
225+
if hasattr(plgpu, 'CompilerParams'):
226+
return plgpu.CompilerParams(num_warps=num_warps, num_stages=num_stages)
227+
else:
228+
return plgpu.TritonCompilerParams(
229+
num_warps=num_warps, num_stages=num_stages
230+
)
226231
else:
227232
return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))

0 commit comments

Comments
 (0)