TypeError: ShapedArray with Sharding ({V:data}) invalid in backward pass (shard_map + slogdet + folx) #40

@zhangylch

Description

Hello, I encountered a TypeError when computing the gradient of a function inside jax.shard_map that includes jnp.linalg.slogdet.

The specific error is:

  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 543, in backward_pass3
    rule(cts_in, *map(read, eqn.invars), **eqn.params)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2431, in _pjit_transpose_fancy
    trans_jaxpr, out_tree = _transpose_jaxpr_fancy(jaxpr, in_tree, (*in_avals,), specs)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2478, in _transpose_jaxpr_fancy
    trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(transposed, debug_info=dbg), in_avals)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/profiler.py", line 359, in wrapper
    return func(*args, **kwargs)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2409, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
    return self.f_transformed(*args, **kwargs)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/pjit.py", line 2473, in transposed
    ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 555, in backward_pass3
    cts_out = rule(cts_in, *map(up, primals), **eqn.params)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/interpreters/ad.py", line 1233, in linear_transpose2
    return transpose_rule(cotangent, *args, **kwargs)
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 6959, in _split_transpose_rule
    _zeros(t.aval) if type(t) is ad_util.Zero else t
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3655, in full_like
    val = core.pvary(val, tuple(core.typeof(x).vma))
  File "/users/ylz/.conda/envs/jax_0_7_1/lib/python3.13/site-packages/jax/_src/core.py", line 1809, in get_aval
    raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
TypeError: Argument 'float32[16,1]{V:data}' of type '<class 'jax._src.core.ShapedArray'>' is not a valid JAX type.

This seems to happen when:

  1. Running inside shard_map (distributed context).
  2. Computing higher-order derivatives (the gradient of a folx forward Laplacian; it works fine when I replace the folx calculation with jax.hessian, as in the sketch right after this list).
  3. The computation involves slogdet (which internally involves inverses/traces).
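
For reference, the jax.hessian-based variant mentioned in point 2 is roughly the following. This is only a minimal sketch (laplacian_via_hessian is a made-up helper name, not part of my code); it is the dense replacement for the folx forward Laplacian with which the backward pass does not crash for me.

import jax
import jax.numpy as jnp

def laplacian_via_hessian(f, x):
    # f is a scalar-valued function of a 1D input array x.
    # The Laplacian is the trace of the Hessian of f at x.
    return f(x), jnp.trace(jax.hessian(f)(x))

# Hypothetical drop-in for lap_op in the script below:
# def lap_op(p, x):
#     return laplacian_via_hessian(lambda y: model.apply(p, y), x)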

Here is a minimal script that reproduces the crash.

import jax
import jax.numpy as jnp
import folx
from flax import linen as nn
from jax import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import os

LOCAL_DEVICES = [0, 1]

def reproduce_fix():
    jax.distributed.initialize(local_device_ids=LOCAL_DEVICES)

    devices = jax.devices()
    mesh = Mesh(devices, axis_names=('data',))  # 1D mesh with a single 'data' axis

    N_dim = 16
    Input_Size = N_dim * N_dim
    GlobalBatch = 16 * len(devices)

    # Global batch, sharded across devices along the 'data' axis
    x_host = jax.random.normal(jax.random.PRNGKey(0), (GlobalBatch, Input_Size))
    sharding_x = NamedSharding(mesh, P('data'))
    x_sharded = jax.device_put(x_host, sharding_x)

    class LogDetNet(nn.Module):
        @nn.compact
        def __call__(self, x):
            matrix = x.reshape((N_dim, N_dim))  # flat input -> N_dim x N_dim matrix
            matrix = nn.Dense(N_dim)(matrix)
            matrix = nn.tanh(matrix)

            sign, logdet = jnp.linalg.slogdet(matrix)
            return logdet

    model = LogDetNet()

    key = jax.random.PRNGKey(42)
    dummy_input = jnp.zeros((Input_Size,))
    params = model.init(key, dummy_input)

    # Replicate the parameters on every device (empty PartitionSpec)
    sharding_params = jax.tree.map(lambda leaf: NamedSharding(mesh, P()), params)
    params_replicated = jax.device_put(params, sharding_params)

    def lap_op(p, x):
        # Forward Laplacian (folx) of the model output with respect to x
        f_closure = lambda y: model.apply(p, y)
        lap_operator = folx.ForwardLaplacianOperator(0)(f_closure)
        return lap_operator(x)

    lap_op_vmap = jax.vmap(lap_op, in_axes=(None, 0))

    def train_epoch_fn(init_params, x_local):

        def local_loss(p, x):
            log_psi, lap = lap_op_vmap(p, x)
            return jnp.sum(log_psi) + jnp.sum(lap)

        grads = jax.grad(local_loss)(init_params, x_local)

        return jax.lax.pmean(grads, axis_name='data')

    # shard_map runs train_epoch_fn on each device's local shard; pmean averages the grads
    distributed_train = jax.jit(shard_map(
        train_epoch_fn,
        mesh=mesh,
        in_specs=(P(), P('data')),
        out_specs=P()
    ),
      in_shardings=(sharding_params, sharding_x),
      out_shardings=sharding_params
    )

    return distributed_train(params_replicated, x_sharded)

if __name__ == "__main__":
    os.environ["JAX_TRACEBACK_FILTERING"] = "off"

    grads = reproduce_fix()

System info (python version, jaxlib version, accelerator, etc.)

Python 3.13, jax 0.8.1, GPU (H100/A100), latest folx
