Description
Hello, I encountered a TypeError when computing the gradient of a function that calls `jnp.linalg.slogdet` inside `jax.shard_map`.
The specific error is:

```
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 543, in backward_pass3
    rule(cts_in, *map(read, eqn.invars), **eqn.params)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2431, in _pjit_transpose_fancy
    trans_jaxpr, out_tree = _transpose_jaxpr_fancy(jaxpr, in_tree, (*in_avals,), specs)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2478, in _transpose_jaxpr_fancy
    trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(transposed, debug_info=dbg), in_avals)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/profiler.py", line 359, in wrapper
    return func(*args, **kwargs)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2409, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
    return self.f_transformed(*args, **kwargs)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2473, in transposed
    ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 555, in backward_pass3
    cts_out = rule(cts_in, *map(up, primals), **eqn.params)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 1233, in linear_transpose2
    return transpose_rule(cotangent, *args, **kwargs)
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 6959, in _split_transpose_rule
    _zeros(t.aval) if type(t) is ad_util.Zero else t
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3655, in full_like
    val = core.pvary(val, tuple(core.typeof(x).vma))
File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/core.py", line 1809, in get_aval
    raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
TypeError: Argument 'float32[16,1]{V:data}' of type '<class 'jax._src.core.ShapedArray'>' is not a valid JAX type.
```
This seems to happen when:
- Running inside `shard_map` (a distributed context).
- Computing higher-order derivatives (the gradient of a Laplacian; it works fine when I replace the folx calculation with `jax.hessian`, see the sketch after this list).
- The computation involves `slogdet` (which involves inverses/trace).
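For reference, this is roughly what I mean by the `jax.hessian` replacement: computing the Laplacian as the trace of the dense Hessian. A minimal sketch (the name `laplacian_via_hessian` is mine, not a folx API); it returns the same `(value, laplacian)` pair that the folx operator's output is unpacked into in the script below:

```python
import jax
import jax.numpy as jnp

def laplacian_via_hessian(f):
    # Laplacian as the trace of the dense Hessian of a scalar function f.
    # Swapping this in for folx.ForwardLaplacianOperator(0)(f) avoids the
    # TypeError in my runs, per the note above. (Sketch only.)
    def lap(x):
        return f(x), jnp.trace(jax.hessian(f)(x))
    return lap
```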
Here is a minimal script that reproduces the crash:

```python
import os

import jax
import jax.numpy as jnp
import folx
from flax import linen as nn
from jax import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LOCAL_DEVICES = [0, 1]


def reproduce_fix():
    jax.distributed.initialize(local_device_ids=LOCAL_DEVICES)
    devices = jax.devices()
    mesh = Mesh(devices, axis_names=('data',))

    N_dim = 16
    Input_Size = N_dim * N_dim
    GlobalBatch = 16 * len(devices)

    x_host = jax.random.normal(jax.random.PRNGKey(0), (GlobalBatch, Input_Size))
    sharding_x = NamedSharding(mesh, P('data'))
    x_sharded = jax.device_put(x_host, sharding_x)

    class LogDetNet(nn.Module):
        @nn.compact
        def __call__(self, x):
            matrix = x.reshape((N_dim, N_dim))
            matrix = nn.Dense(N_dim)(matrix)
            matrix = nn.tanh(matrix)
            sign, logdet = jnp.linalg.slogdet(matrix)
            return logdet

    model = LogDetNet()
    key = jax.random.PRNGKey(42)
    dummy_input = jnp.zeros((Input_Size,))
    params = model.init(key, dummy_input)

    sharding_params = jax.tree.map(lambda _: NamedSharding(mesh, P()), params)
    params_replicated = jax.device_put(params, sharding_params)

    def lap_op(p, x):
        f_closure = lambda y: model.apply(p, y)
        lap_operator = folx.ForwardLaplacianOperator(0)(f_closure)
        return lap_operator(x)

    lap_op_vmap = jax.vmap(lap_op, in_axes=(None, 0))

    def train_epoch_fn(init_params, x_local):
        def local_loss(p, x):
            log_psi, lap = lap_op_vmap(p, x)
            return jnp.sum(log_psi) + jnp.sum(lap)

        grads = jax.grad(local_loss)(init_params, x_local)
        return jax.lax.pmean(grads, axis_name='data')

    distributed_train = jax.jit(
        shard_map(
            train_epoch_fn,
            mesh=mesh,
            in_specs=(P(), P('data')),
            out_specs=P(),
        ),
        in_shardings=(sharding_params, sharding_x),
        out_shardings=sharding_params,
    )

    return distributed_train(params_replicated, x_sharded)


if __name__ == "__main__":
    os.environ["JAX_TRACEBACK_FILTERING"] = "off"
    grads = reproduce_fix()
```
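For what it's worth, the conditions listed above suggest the same computation without `shard_map` completes fine. A minimal sketch of that control, reusing `lap_op_vmap` and the other definitions from the script above (`single_process_grads` is a name I made up, not from the original run):

```python
def single_process_grads(params, x):
    # Same loss and gradient as train_epoch_fn, minus shard_map/pmean.
    def loss(p, xs):
        log_psi, lap = lap_op_vmap(p, xs)
        return jnp.sum(log_psi) + jnp.sum(lap)
    return jax.jit(jax.grad(loss))(params, x)
```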
System info (python version, jaxlib version, accelerator, etc.)
Python 3.13, jax 0.8.1, GPU (H100/A100), latest folx