We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c4be0d4 commit 0c048b7Copy full SHA for 0c048b7
folx/hessian.py
@@ -6,7 +6,6 @@
6
import jax.flatten_util as jfu
7
import jax.numpy as jnp
8
import jax.tree_util as jtu
9
-import jaxlib.xla_extension
10
import numpy as np
11
12
from .ad import hessian, jacrev
@@ -230,7 +229,7 @@ def idx_fn(x):
230
229
# potentially fails if the arrays are too large.
231
# +1 because we need to accomodate the -1.
232
arrs = np.asarray(idx_fn(inp), dtype=int)
233
- except jaxlib.xla_extension.XlaRuntimeError:
+ except RuntimeError:
234
logging.info(
235
'Failed to find unique elements on GPU, falling back to CPU. This will be slow.'
236
)
0 commit comments