Skip to content

Commit d05c107

Browse files
authored
shard_map compatibility (#36)
* pvary I * workaround to fix sparse version (jax>=0.7.2) * restore compatibility with older versions of jax * add a test * fmt * fix test
1 parent ea364bc commit d05c107

File tree

4 files changed

+77
-21
lines changed

4 files changed

+77
-21
lines changed

folx/ad.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,10 @@ def flat_f(x):
7171

7272
out = flat_f(flat_primals)
7373

74-
result = jax.vmap(vjp(flat_f, flat_primals))(
75-
jnp.eye(out.size, dtype=out.dtype)
76-
)[0]
74+
eye = jnp.eye(out.size, dtype=out.dtype)
75+
if hasattr(jax.lax, 'pvary'):
76+
eye = jax.lax.pvary(eye, tuple(jax.typeof(out).vma))
77+
result = jax.vmap(vjp(flat_f, flat_primals))(eye)[0]
7778
result = jax.vmap(unravel, out_axes=0)(result)
7879
if len(primals) == 1:
7980
return result[0]

folx/api.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,18 @@ def get_indices(mask, out_mask):
131131

132132
if isinstance(outputs, np.ndarray):
133133
with jax.ensure_compile_time_eval():
134-
result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
134+
if hasattr(jax.sharding, 'use_abstract_mesh'): # jax>=0.7.2
135+
# see https://github.com/jax-ml/jax/discussions/31461
136+
with jax.sharding.use_abstract_mesh(
137+
jax.sharding.AbstractMesh((), ())
138+
):
139+
result = np.asarray(
140+
get_indices(flat_mask, flat_outputs), dtype=int
141+
).T
142+
else:
143+
result = np.asarray(
144+
get_indices(flat_mask, flat_outputs), dtype=int
145+
).T
135146
else:
136147
result = get_indices(flat_mask, flat_outputs).T
137148
return result.reshape(mask.shape)

folx/wrapped_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ def custom_jvp(jacobian, tangent, sign):
227227
log_det_jvp = jac_dot_tangent.real
228228
else:
229229
sign_jvp = jnp.zeros((), dtype=jac_dot_tangent.dtype)
230+
if hasattr(jax.lax, 'pvary'):
231+
sign_jvp = jax.lax.pvary(
232+
sign_jvp, tuple(jax.typeof(jac_dot_tangent).vma)
233+
)
230234
log_det_jvp = jac_dot_tangent
231235
return (sign_jvp, log_det_jvp)
232236

test/test_layers.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import functools
2+
from functools import partial
23

34
import jax
45
import jax.numpy as jnp
56
import jax.tree_util as jtu
67
import numpy as np
78
from laplacian_testcase import LaplacianTestCase
9+
from packaging.version import Version
810
from parameterized import parameterized
911

1012
from folx import (
@@ -174,28 +176,66 @@ def test_slogdet(self, test_complex: bool):
174176
w = w + 1j * np.random.normal(size=w.shape)
175177

176178
@jax.jit
177-
def f(x):
179+
def _f(w, x):
178180
return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16)))
179181

182+
f = partial(_f, w)
183+
180184
for sparsity in [0, x.size]:
181-
with self.subTest(sparsity=sparsity):
182-
sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
183-
self.assertEqual(log_y.x.shape, f(x)[1].shape)
184-
self.assert_allclose(log_y.x, f(x)[1])
185-
self.assert_allclose(
186-
log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
187-
)
188-
self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])
189-
190-
self.assertEqual(sign_y.shape, log_y.x.shape)
191-
if test_complex:
192-
self.assertIsInstance(sign_y, FwdLaplArray)
185+
for use_shard_map in [False, True]:
186+
with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map):
187+
if use_shard_map and (
188+
Version(jax.__version__) < Version('0.7.1')
189+
or (
190+
sparsity != 0
191+
and Version(jax.__version__) < Version('0.7.2')
192+
)
193+
):
194+
self.skipTest('jax version too old')
195+
if use_shard_map:
196+
mesh = jax.sharding.Mesh(
197+
jax.devices()[:1],
198+
'i',
199+
axis_types=jax.sharding.AxisType.Explicit,
200+
)
201+
202+
@jax.jit
203+
@partial(
204+
jax.shard_map,
205+
in_specs=(jax.P(), jax.P('i')),
206+
out_specs=jax.P('i'),
207+
)
208+
@partial(jax.vmap, in_axes=(None, 0))
209+
def forward_laplacian_sh(w, x):
210+
return forward_laplacian(partial(_f, w), sparsity)(x)
211+
212+
with jax.set_mesh(mesh):
213+
x_sh = jax.sharding.reshard(x[None], jax.P('i'))
214+
w_sh = jax.sharding.reshard(w, jax.P())
215+
sign_y, log_y = jax.tree.map(
216+
lambda x: x[0], forward_laplacian_sh(w_sh, x_sh)
217+
)
218+
else:
219+
sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x)
220+
221+
self.assertEqual(log_y.x.shape, f(x)[1].shape)
222+
self.assert_allclose(log_y.x, f(x)[1])
193223
self.assert_allclose(
194-
sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
224+
log_y.jacobian.dense_array, self.jacobian(f, x)[1].T
195225
)
196-
self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
197-
else:
198-
self.assertIsInstance(sign_y, jax.Array)
226+
self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1])
227+
228+
self.assertEqual(sign_y.shape, log_y.x.shape)
229+
if test_complex:
230+
self.assertIsInstance(sign_y, FwdLaplArray)
231+
self.assert_allclose(
232+
sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T
233+
)
234+
self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0])
235+
else:
236+
self.assertIsInstance(sign_y, jax.Array)
237+
del sign_y
238+
del log_y
199239

200240
def test_custom_hessian(self):
201241
x = np.random.normal(size=(16,))

0 commit comments

Comments
 (0)