Skip to content

Commit 552e5d6

Browse files
committed
fix tests
1 parent d05c107 commit 552e5d6

File tree

2 files changed

+912
-772
lines changed

2 files changed

+912
-772
lines changed

test/test_layers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_elementwise(self, test_complex: bool):
5050
x = x + 1j * np.random.randn(10)
5151
for f in functions:
5252
for sparsity in [0, x.size]:
53-
with self.subTest(sparsity=sparsity, f=f):
53+
with self.subTest(sparsity=sparsity, f=getattr(f, '__name__', str(f))):
5454
y = forward_laplacian(f, sparsity)(x)
5555
self.assertEqual(y.x.shape, x.shape, msg=f'{f}')
5656
self.assert_allclose(y.x, f(x))
@@ -79,7 +79,9 @@ def f_right(x):
7979

8080
for sparsity in [0, x1.size]:
8181
# test both arguments
82-
with self.subTest(sparsity=sparsity, f=f, binary=True):
82+
with self.subTest(
83+
sparsity=sparsity, f=getattr(f, '__name__', str(f)), binary=True
84+
):
8385
y = forward_laplacian(wrapped_f, sparsity)(x)
8486
self.assertEqual(y.x.shape, x1.shape, msg=f'{f}')
8587
self.assert_allclose(y.x, wrapped_f(x))
@@ -89,7 +91,9 @@ def f_right(x):
8991
self.assert_allclose(y.laplacian, self.laplacian(wrapped_f, x))
9092

9193
# test left hand argument
92-
with self.subTest(sparsity=sparsity, f=f, binary=False):
94+
with self.subTest(
95+
sparsity=sparsity, f=getattr(f, '__name__', str(f)), binary=False
96+
):
9397
y = forward_laplacian(f_left, sparsity)(x1)
9498
self.assertEqual(y.x.shape, x1.shape, msg=f'{f}')
9599
self.assert_allclose(y.x, f_left(x1))
@@ -99,7 +103,9 @@ def f_right(x):
99103
self.assert_allclose(y.laplacian, self.laplacian(f_left, x1))
100104

101105
# test right hand argument
102-
with self.subTest(sparsity=sparsity, f=f, binary=False):
106+
with self.subTest(
107+
sparsity=sparsity, f=getattr(f, '__name__', str(f)), binary=False
108+
):
103109
y = forward_laplacian(f_right, sparsity)(x1)
104110
self.assertEqual(y.x.shape, x1.shape, msg=f'{f}')
105111
self.assert_allclose(y.x, f_right(x1))
@@ -280,7 +286,7 @@ def f(x, dtype):
280286
jnp.complex64,
281287
jnp.complex128,
282288
]:
283-
with self.subTest(dtype=dtype):
289+
with self.subTest(dtype=dtype.__name__):
284290
y = jax.jit(forward_laplacian(functools.partial(f, dtype=dtype)))(x)
285291
self.assertEqual(y.x.dtype, dtype)
286292
self.assertEqual(y.jacobian.dense_array.dtype, dtype)
@@ -297,7 +303,7 @@ def f(x, dtype):
297303
jnp.uint32,
298304
jnp.uint64,
299305
]:
300-
with self.subTest(dtype=dtype):
306+
with self.subTest(dtype=dtype.__name__):
301307
y = jax.jit(forward_laplacian(functools.partial(f, dtype=dtype)))(x)
302308
self.assertIsInstance(y, jax.Array)
303309

0 commit comments

Comments
 (0)