Skip to content

Commit 30b053a

Browse files
committed
more robust indexing handling
1 parent 34e41fa commit 30b053a

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

folx/jvp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,10 @@ def sparse_index_jvp(
275275
# https://github.com/google/jax/pull/3370
276276
with jax.ensure_compile_time_eval():
277277
extra_filled = jtu.tree_map(
278-
lambda x: jnp.full(x.shape, -1, dtype=jnp.int32), extra_args
278+
lambda x: jnp.full(
279+
x.shape if isinstance(x, jax.Array) else (), -1, dtype=jnp.int32
280+
),
281+
extra_args,
279282
)
280283

281284
def _merged_fwd(*args):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "folx"
3-
version = "0.2.18"
3+
version = "0.2.19"
44
description = "Forward Laplacian for JAX"
55
authors = [
66
{ name = "Nicholas Gao", email = "n.gao@tum.de" },

0 commit comments

Comments
 (0)