File tree Expand file tree Collapse file tree 4 files changed +20
-4
lines changed
folx/experimental/pallas/attention Expand file tree Collapse file tree 4 files changed +20
-4
lines changed Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .mhsa import mhsa_kernel , reference_mhsa_kernel
1115from .mhsea import mhsea_kernel , reference_mhsea_kernel
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from folx import forward_laplacian
1115from folx .api import FwdJacobian , FwdLaplArray
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .utils import (
1115 big_number ,
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .utils import (
1115 big_number ,
You can’t perform that action at this time.
0 commit comments