Skip to content

Commit 34e41fa

Browse files
committed
bump version; support newer JAX versions
1 parent f9b237a commit 34e41fa

File tree

3 files changed

+977
-595
lines changed

3 files changed

+977
-595
lines changed

folx/interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def eval_laplacian(eqn: JaxprEqn, invals):
214214
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
215215
elif eqn.primitive.name == 'scan':
216216
outvals = eval_scan(eqn, invals)
217-
elif eqn.primitive.name == 'pjit':
217+
elif eqn.primitive.name in ('jit', 'pjit'):
218218
outvals = eval_pjit(eqn, invals)
219219
elif eqn.primitive.name == 'custom_jvp_call':
220220
outvals = eval_custom_jvp(eqn, invals)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "folx"
3-
version = "0.2.17"
3+
version = "0.2.18"
44
description = "Forward Laplacian for JAX"
55
authors = [
66
{ name = "Nicholas Gao", email = "n.gao@tum.de" },

0 commit comments

Comments
 (0)