Skip to content

Commit a4a1401

Browse files
committed
Import statement has to differ across versions
1 parent 2b2e45b commit a4a1401

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

folx/experimental/pallas/attention/custom_gradients.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8-
from jax.experimental.pallas import gpu as plgpu
8+
9+
try:
10+
from jax.experimental.pallas import triton as plgpu
11+
except ImportError:
12+
from jax.experimental.pallas import gpu as plgpu
913

1014
from .mhsa import mhsa_kernel, reference_mhsa_kernel
1115
from .mhsea import mhsea_kernel, reference_mhsea_kernel

folx/experimental/pallas/attention/forward_laplacian.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8-
from jax.experimental.pallas import gpu as plgpu
8+
9+
try:
10+
from jax.experimental.pallas import triton as plgpu
11+
except ImportError:
12+
from jax.experimental.pallas import gpu as plgpu
913

1014
from folx import forward_laplacian
1115
from folx.api import FwdJacobian, FwdLaplArray

folx/experimental/pallas/attention/mhsa.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8-
from jax.experimental.pallas import gpu as plgpu
8+
9+
try:
10+
from jax.experimental.pallas import triton as plgpu
11+
except ImportError:
12+
from jax.experimental.pallas import gpu as plgpu
913

1014
from .utils import (
1115
big_number,

folx/experimental/pallas/attention/mhsea.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import jax
66
import jax.numpy as jnp
77
from jax.experimental import pallas as pl
8-
from jax.experimental.pallas import gpu as plgpu
8+
9+
try:
10+
from jax.experimental.pallas import triton as plgpu
11+
except ImportError:
12+
from jax.experimental.pallas import gpu as plgpu
913

1014
from .utils import (
1115
big_number,

0 commit comments

Comments
 (0)