diff --git a/pyproject.toml b/pyproject.toml index 3a759be..4175f5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ local_scheme = "no-local-version" [project.optional-dependencies] jax = [ - "jax>=0.4.35,<0.10.0; python_version >= '3.11'", - "jaxlib>=0.4.35,<0.10.0; python_version >= '3.11'", + "jax>=0.7.0,<0.11.0; python_version >= '3.11'", + "jaxlib>=0.7.0,<0.11.0; python_version >= '3.11'", "jaxnnls==1.0.1; python_version >= '3.11'" ] optional = [