Description
Based on a talk on NODEs on YouTube I came across this package, and it looks perfect for a project we are planning (thanks for the great talk!). Now, one of the platforms where we want to run our code does not support JAX/XLA/TensorFlow, just ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because the unsupported `unvmap_any` is used (at compile time!) to deduce the number of iterations needed.
Minimal example:
```python
import tensorflow as tf
import jax.numpy as jnp
import tf2onnx
from diffrax import diffeqsolve, ODETerm, Euler
from jax.experimental import jax2tf


def simulate(y0):
    solution = diffeqsolve(
        terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
        t0=0, t1=1, dt0=0.1, y0=y0)
    return solution.ys[0]


# This works:
x = simulate(100)
assert jnp.isclose(x, jnp.exp(-1) * 100, atol=0.1, rtol=0.1)

simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))

# Does not work:
# simulate_tf(100)
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented

# Also doesn't work:
tf2onnx.convert.from_function(
    simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
# simulate_tf(100)
# => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
```

For us, it would be really nice to train with JAX on a GPU/TSP, then transfer to this specific piece of hardware with just ONNX support for inference (at that point I don't need gradient calculation anymore).
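For context, my (possibly wrong) understanding of why the trace needs `unvmap_any`: under `vmap`, a per-example condition like "has this trajectory reached `t1` yet?" becomes a batched boolean, but a while-loop predicate must be a single scalar, so the solver reduces it with an "any" across the batch dimensions. A plain-Python analogue of that reduction (my own sketch, not diffrax code):

```python
# Hypothetical plain-Python analogue of the loop-termination logic.
# In diffrax, `unvmap_any` is a custom JAX primitive doing this reduction
# under vmap, which is exactly why jax2tf has no conversion rule for it.
def keep_stepping(t_batch, t1):
    # Step the whole batch while ANY element has not yet reached t1.
    return any(t < t1 for t in t_batch)

print(keep_stepping([0.9, 1.0, 1.0], 1.0))  # one element unfinished -> True
print(keep_stepping([1.0, 1.0, 1.0], 1.0))  # all finished -> False
```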
Of course, solving this might be completely outside the scope of the project, and there are other options, like writing the solvers from scratch or using existing solvers in TF/PyTorch.
Currently my knowledge of JAX is limited (hopefully that will soon improve!). If this is the only function stopping Diffrax from being TensorFlow-convertible, maybe a small workaround would be possible. I'm also happy with an answer like "no, we don't do this" or "send us a PR if you want this fixed".