In a recent JAX release, the way `pl.pallas_call` accepts Triton compiler parameters changed from
```python
pl.pallas_call(
    ...,
    compiler_params=dict(triton=dict(
        num_warps=num_warps,
        num_stages=num_stages,
    )),
    ...
)
```
to
```python
pl.pallas_call(
    ...,
    compiler_params=plgpu.TritonCompilerParams(
        num_warps=num_warps, num_stages=num_stages),
    ...,
)
```
We could either drop support for older JAX versions or try to make the code backwards compatible with both APIs.