We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 30b053a commit ea364bcCopy full SHA for ea364bc
folx/experimental/pallas/attention/utils.py
@@ -222,6 +222,11 @@ def big_number(dtype) -> float:
222
223
def compiler_params(num_warps, num_stages):
224
if Version(jax.__version__) >= Version('0.4.34'):
225
- return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
+ if hasattr(plgpu, 'CompilerParams'):
226
+ return plgpu.CompilerParams(num_warps=num_warps, num_stages=num_stages)
227
+ else:
228
+ return plgpu.TritonCompilerParams(
229
+ num_warps=num_warps, num_stages=num_stages
230
+ )
231
else:
232
return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))
0 commit comments