From 52115348c26442da2c7dd46b3303f469e4c44729 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 20 May 2026 20:21:50 +0100 Subject: [PATCH] Bump JAX floor to 0.7.0 to match nufftax 0.4.0 runtime needs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Raise the [jax] optional extra to `jax>=0.7.0,<0.11.0` (was `jax>=0.4.35,<0.10.0`). nufftax 0.4.0 calls `jax.experimental.pallas.triton.CompilerParams`, which was renamed from `TritonCompilerParams` in JAX 0.7.0. The previous floor let pip resolve to JAX 0.4.x, producing installs that crashed at runtime on `subplot_interferometer_dirty_images`. JAX 0.7.0+ also requires Python 3.11+, already enforced by the existing python_version marker. The ceiling moves to <0.11.0 so we still permit the just-released 0.10.x line — a code audit confirmed PyAuto uses neither `jax.pmap` nor `PartitionSpec` tuple equality, the two breaking changes introduced in JAX 0.10.0. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3a759be..4175f5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ local_scheme = "no-local-version" [project.optional-dependencies] jax = [ - "jax>=0.4.35,<0.10.0; python_version >= '3.11'", - "jaxlib>=0.4.35,<0.10.0; python_version >= '3.11'", + "jax>=0.7.0,<0.11.0; python_version >= '3.11'", + "jaxlib>=0.7.0,<0.11.0; python_version >= '3.11'", "jaxnnls==1.0.1; python_version >= '3.11'" ] optional = [