Skip to content

Commit 0c048b7

Browse files
committed
support newer jax versions
1 parent c4be0d4 commit 0c048b7

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

folx/hessian.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import jax.flatten_util as jfu
77
import jax.numpy as jnp
88
import jax.tree_util as jtu
9-
import jaxlib.xla_extension
109
import numpy as np
1110

1211
from .ad import hessian, jacrev
@@ -230,7 +229,7 @@ def idx_fn(x):
230229
# potentially fails if the arrays are too large.
231230
# +1 because we need to accomodate the -1.
232231
arrs = np.asarray(idx_fn(inp), dtype=int)
233-
except jaxlib.xla_extension.XlaRuntimeError:
232+
except RuntimeError:
234233
logging.info(
235234
'Failed to find unique elements on GPU, falling back to CPU. This will be slow.'
236235
)

0 commit comments

Comments
 (0)