Skip to content

Commit 47a0e8d

Browse files
authored
shard_map compatibility part 2 (#39)
* add missing pvary * add regression test * support old jax * test on a single device only (its sufficient to trigger the error) * skip test on old jax
1 parent 552e5d6 commit 47a0e8d

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

folx/ad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def jvp_fun(s):
9494
return jax.jvp(f, primals, unravel(s))[1]
9595

9696
eye = jnp.eye(flat_primals.size, dtype=flat_primals.dtype)
97+
if hasattr(jax.lax, 'pvary'):
98+
eye = jax.lax.pvary(eye, tuple(jax.typeof(flat_primals).vma))
9799
J = jax.vmap(jvp_fun, out_axes=-1)(eye)
98100
return J
99101

test/test_shard_map.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from functools import partial
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import pytest
6+
from packaging.version import Version
7+
8+
from folx import forward_laplacian
9+
10+
11+
@pytest.mark.skipif(
12+
Version(jax.__version__) < Version('0.7.1'), reason='jax version too old'
13+
)
14+
def test_shard_map_bug_integer_pow():
15+
# see https://github.com/microsoft/folx/issues/38
16+
17+
def f(w, x):
18+
return jax.lax.integer_pow(x @ w, 1)
19+
20+
@jax.smap(out_axes=0, in_axes=(None, 0), axis_name='i')
21+
@partial(jax.vmap, in_axes=(None, 0))
22+
def test(w, x):
23+
return forward_laplacian(partial(f, w))(x)
24+
25+
x = jnp.ones((1, 16))
26+
w = jnp.ones((16, 16))
27+
28+
with jax.set_mesh(jax.sharding.Mesh(jax.devices()[:1], 'i')):
29+
test(w, x)

0 commit comments

Comments
 (0)