|
1 | 1 | import functools |
| 2 | +from functools import partial |
2 | 3 |
|
3 | 4 | import jax |
4 | 5 | import jax.numpy as jnp |
5 | 6 | import jax.tree_util as jtu |
6 | 7 | import numpy as np |
7 | 8 | from laplacian_testcase import LaplacianTestCase |
| 9 | +from packaging.version import Version |
8 | 10 | from parameterized import parameterized |
9 | 11 |
|
10 | 12 | from folx import ( |
@@ -174,28 +176,66 @@ def test_slogdet(self, test_complex: bool): |
174 | 176 | w = w + 1j * np.random.normal(size=w.shape) |
175 | 177 |
|
176 | 178 | @jax.jit |
177 | | - def f(x): |
| 179 | + def _f(w, x): |
178 | 180 | return jnp.linalg.slogdet(jnp.tanh((x @ w).reshape(16, 16))) |
179 | 181 |
|
| 182 | + f = partial(_f, w) |
| 183 | + |
180 | 184 | for sparsity in [0, x.size]: |
181 | | - with self.subTest(sparsity=sparsity): |
182 | | - sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x) |
183 | | - self.assertEqual(log_y.x.shape, f(x)[1].shape) |
184 | | - self.assert_allclose(log_y.x, f(x)[1]) |
185 | | - self.assert_allclose( |
186 | | - log_y.jacobian.dense_array, self.jacobian(f, x)[1].T |
187 | | - ) |
188 | | - self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1]) |
189 | | - |
190 | | - self.assertEqual(sign_y.shape, log_y.x.shape) |
191 | | - if test_complex: |
192 | | - self.assertIsInstance(sign_y, FwdLaplArray) |
| 185 | + for use_shard_map in [False, True]: |
| 186 | + with self.subTest(sparsity=sparsity, use_shard_map=use_shard_map): |
| 187 | + if use_shard_map and ( |
| 188 | + Version(jax.__version__) < Version('0.7.1') |
| 189 | + or ( |
| 190 | + sparsity != 0 |
| 191 | + and Version(jax.__version__) < Version('0.7.2') |
| 192 | + ) |
| 193 | + ): |
| 194 | + self.skipTest('jax version too old') |
| 195 | + if use_shard_map: |
| 196 | + mesh = jax.sharding.Mesh( |
| 197 | + jax.devices()[:1], |
| 198 | + 'i', |
| 199 | + axis_types=jax.sharding.AxisType.Explicit, |
| 200 | + ) |
| 201 | + |
| 202 | + @jax.jit |
| 203 | + @partial( |
| 204 | + jax.shard_map, |
| 205 | + in_specs=(jax.P(), jax.P('i')), |
| 206 | + out_specs=jax.P('i'), |
| 207 | + ) |
| 208 | + @partial(jax.vmap, in_axes=(None, 0)) |
| 209 | + def forward_laplacian_sh(w, x): |
| 210 | + return forward_laplacian(partial(_f, w), sparsity)(x) |
| 211 | + |
| 212 | + with jax.set_mesh(mesh): |
| 213 | + x_sh = jax.sharding.reshard(x[None], jax.P('i')) |
| 214 | + w_sh = jax.sharding.reshard(w, jax.P()) |
| 215 | + sign_y, log_y = jax.tree.map( |
| 216 | + lambda x: x[0], forward_laplacian_sh(w_sh, x_sh) |
| 217 | + ) |
| 218 | + else: |
| 219 | + sign_y, log_y = jax.jit(forward_laplacian(f, sparsity))(x) |
| 220 | + |
| 221 | + self.assertEqual(log_y.x.shape, f(x)[1].shape) |
| 222 | + self.assert_allclose(log_y.x, f(x)[1]) |
193 | 223 | self.assert_allclose( |
194 | | - sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T |
| 224 | + log_y.jacobian.dense_array, self.jacobian(f, x)[1].T |
195 | 225 | ) |
196 | | - self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0]) |
197 | | - else: |
198 | | - self.assertIsInstance(sign_y, jax.Array) |
| 226 | + self.assert_allclose(log_y.laplacian, self.laplacian(f, x)[1]) |
| 227 | + |
| 228 | + self.assertEqual(sign_y.shape, log_y.x.shape) |
| 229 | + if test_complex: |
| 230 | + self.assertIsInstance(sign_y, FwdLaplArray) |
| 231 | + self.assert_allclose( |
| 232 | + sign_y.jacobian.dense_array, self.jacobian(f, x)[0].T |
| 233 | + ) |
| 234 | + self.assert_allclose(sign_y.laplacian, self.laplacian(f, x)[0]) |
| 235 | + else: |
| 236 | + self.assertIsInstance(sign_y, jax.Array) |
| 237 | + del sign_y |
| 238 | + del log_y |
199 | 239 |
|
200 | 240 | def test_custom_hessian(self): |
201 | 241 | x = np.random.normal(size=(16,)) |
|
0 commit comments