
Make diffeqsolve convertible to TensorFlow #202

@llandsmeer

I came across this package through a talk on neural ODEs (NODEs) on YouTube, and it looks perfect for a project we are planning (thanks for the great talk!). However, one of the platforms where we want to run our code does not support JAX, XLA or TensorFlow, only ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because diffeqsolve uses the `unvmap_any` primitive, which jax2tf cannot translate, to determine (at compile time!) the number of iterations needed.
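To find out before attempting a conversion which primitives a function actually contains, one option is to trace it with `jax.make_jaxpr` and walk the resulting jaxpr, including the sub-jaxprs nested inside control-flow primitives such as `while`, `scan`, `cond` and `pjit`. The sketch below does that; `collect_primitives` and `_sub_jaxprs` are hypothetical helper names, and the code relies only on the public `.eqns`, `.primitive.name` and `.params` attributes of a jaxpr.

```python
import jax
import jax.numpy as jnp


def _sub_jaxprs(value):
    """Yield any jaxprs stored in an equation parameter (hypothetical helper)."""
    if hasattr(value, "jaxpr") and hasattr(value.jaxpr, "eqns"):
        # A ClosedJaxpr wraps a Jaxpr together with its constants.
        yield value.jaxpr
    elif hasattr(value, "eqns"):
        # A bare Jaxpr.
        yield value
    elif isinstance(value, (tuple, list)):
        # e.g. the `branches` parameter of `cond` is a tuple of ClosedJaxprs.
        for item in value:
            yield from _sub_jaxprs(item)


def collect_primitives(jaxpr, found=None):
    """Recursively collect the names of all primitives used in a jaxpr."""
    found = set() if found is None else found
    for eqn in jaxpr.eqns:
        found.add(eqn.primitive.name)
        # Control-flow primitives keep their bodies as sub-jaxprs in eqn.params.
        for param in eqn.params.values():
            for sub in _sub_jaxprs(param):
                collect_primitives(sub, found)
    return found
```

For the `simulate` function in the minimal example below, `collect_primitives(jax.make_jaxpr(simulate)(100.0).jaxpr)` would be expected to contain `'unvmap_any'`, judging from the error message; I have not checked the exact output.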

Minimal example:

```python
import jax.numpy as jnp
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

from diffrax import Euler, ODETerm, diffeqsolve


def simulate(y0):
    # Integrate dy/dt = -y from t=0 to t=1 with fixed-step Euler (dt0 = 0.1).
    solution = diffeqsolve(
        terms=ODETerm(lambda t, y, args: -y), solver=Euler(),
        t0=0, t1=1, dt0=0.1, y0=y0)
    # By default only the solution at t1 is saved, so ys has shape (1,).
    return solution.ys[0]


# This works (Euler with dt0 = 0.1 only approximates exp(-1), hence the loose tolerance)
x = simulate(100.0)
assert jnp.isclose(x, jnp.exp(-1) * 100, atol=.1, rtol=.1)

simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))

# Does not work:
# simulate_tf(100.0)
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented

# Converting to ONNX does not work either; it fails with the same error:
# tf2onnx.convert.from_function(
#     simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
```
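The numbers in the `assert` above can be checked by hand. Assuming dt0 divides the interval evenly, a constant step size fixes the number of Euler steps from t0, t1 and dt0 alone, and for dy/dt = -y every explicit Euler step multiplies y by (1 - dt0). The plain-Python sketch below (no JAX or diffrax involved) reproduces that arithmetic for the values in the minimal example.

```python
import math

# Values from the minimal example: t0=0, t1=1, dt0=0.1, y0=100.
t0, t1, dt0, y0 = 0.0, 1.0, 0.1, 100.0

# With a constant step size the step count depends only on these Python
# scalars; round() guards against results like 9.999999999999998.
n_steps = round((t1 - t0) / dt0)

# Explicit Euler on dy/dt = -y: y_{k+1} = y_k - dt0 * y_k = (1 - dt0) * y_k.
y_euler = y0 * (1 - dt0) ** n_steps
y_exact = y0 * math.exp(-(t1 - t0))

print(n_steps, y_euler, y_exact)
```

This gives 10 steps and an Euler result of about 34.87 against an exact value of about 36.79, which is within the `atol=.1, rtol=.1` tolerance used by `jnp.isclose`.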

For us, it would be really nice to train with JAX on a GPU/TSP and then deploy to this specific piece of hardware, which only supports ONNX, for inference (at that point I no longer need gradients).
Of course, solving this might be completely outside the scope of the project, and there are other options, such as writing the solvers from scratch or using existing solvers in TF/PyTorch.
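A minimal sketch of the "write the solver from scratch" route, assuming only fixed-step Euler is needed for inference: the loop below uses `jax.lax.scan` with a trip count given as a Python int, so its length is known at trace time and it does not go through diffeqsolve (and therefore not through `unvmap_any`). `euler_solve` and `simulate_fixed` are hypothetical names; the vector field keeps the `(t, y, args)` convention of `ODETerm` so the same lambda can be reused. I have not verified that `jax2tf.convert(..., enable_xla=False)` and tf2onnx accept the result.

```python
import jax
import jax.numpy as jnp


def euler_solve(vector_field, y0, t0, t1, n_steps, args=None):
    """Fixed-step explicit Euler from t0 to t1; returns y(t1).

    n_steps must be a Python int so that the scan length is static.
    """
    dt = (t1 - t0) / n_steps
    # Start time of every step, computed up front.
    ts = t0 + dt * jnp.arange(n_steps, dtype=jnp.float32)

    def step(y, t):
        # One explicit Euler update: y_{k+1} = y_k + dt * f(t_k, y_k, args).
        return y + dt * vector_field(t, y, args), None

    y_final, _ = jax.lax.scan(step, jnp.asarray(y0, dtype=jnp.float32), ts)
    return y_final


def simulate_fixed(y0):
    # Same problem as the minimal example: dy/dt = -y on [0, 1] with dt = 0.1.
    return euler_solve(lambda t, y, args: -y, y0, 0.0, 1.0, n_steps=10)
```

`scan` is used rather than `while_loop` because the number of steps never depends on array values here; the trade-off is that adaptive step sizes are no longer possible.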

My knowledge of JAX is currently limited (hopefully this will improve soon!). If this is the only thing preventing Diffrax from being TensorFlow-convertible, maybe a small workaround is possible. I'm also happy with an answer like 'no, we don't do this' or 'send us a PR if you want this fixed'.
