@@ -50,7 +50,7 @@ def test_elementwise(self, test_complex: bool):
5050 x = x + 1j * np .random .randn (10 )
5151 for f in functions :
5252 for sparsity in [0 , x .size ]:
53- with self .subTest (sparsity = sparsity , f = f ):
53+ with self .subTest (sparsity = sparsity , f = getattr ( f , '__name__' , str ( f )) ):
5454 y = forward_laplacian (f , sparsity )(x )
5555 self .assertEqual (y .x .shape , x .shape , msg = f'{ f } ' )
5656 self .assert_allclose (y .x , f (x ))
@@ -79,7 +79,9 @@ def f_right(x):
7979
8080 for sparsity in [0 , x1 .size ]:
8181 # test both arguments
82- with self .subTest (sparsity = sparsity , f = f , binary = True ):
82+ with self .subTest (
83+ sparsity = sparsity , f = getattr (f , '__name__' , str (f )), binary = True
84+ ):
8385 y = forward_laplacian (wrapped_f , sparsity )(x )
8486 self .assertEqual (y .x .shape , x1 .shape , msg = f'{ f } ' )
8587 self .assert_allclose (y .x , wrapped_f (x ))
@@ -89,7 +91,9 @@ def f_right(x):
8991 self .assert_allclose (y .laplacian , self .laplacian (wrapped_f , x ))
9092
9193 # test left hand argument
92- with self .subTest (sparsity = sparsity , f = f , binary = False ):
94+ with self .subTest (
95+ sparsity = sparsity , f = getattr (f , '__name__' , str (f )), binary = False
96+ ):
9397 y = forward_laplacian (f_left , sparsity )(x1 )
9498 self .assertEqual (y .x .shape , x1 .shape , msg = f'{ f } ' )
9599 self .assert_allclose (y .x , f_left (x1 ))
@@ -99,7 +103,9 @@ def f_right(x):
99103 self .assert_allclose (y .laplacian , self .laplacian (f_left , x1 ))
100104
101105 # test right hand argument
102- with self .subTest (sparsity = sparsity , f = f , binary = False ):
106+ with self .subTest (
107+ sparsity = sparsity , f = getattr (f , '__name__' , str (f )), binary = False
108+ ):
103109 y = forward_laplacian (f_right , sparsity )(x1 )
104110 self .assertEqual (y .x .shape , x1 .shape , msg = f'{ f } ' )
105111 self .assert_allclose (y .x , f_right (x1 ))
@@ -280,7 +286,7 @@ def f(x, dtype):
280286 jnp .complex64 ,
281287 jnp .complex128 ,
282288 ]:
283- with self .subTest (dtype = dtype ):
289+ with self .subTest (dtype = dtype . __name__ ):
284290 y = jax .jit (forward_laplacian (functools .partial (f , dtype = dtype )))(x )
285291 self .assertEqual (y .x .dtype , dtype )
286292 self .assertEqual (y .jacobian .dense_array .dtype , dtype )
@@ -297,7 +303,7 @@ def f(x, dtype):
297303 jnp .uint32 ,
298304 jnp .uint64 ,
299305 ]:
300- with self .subTest (dtype = dtype ):
306+ with self .subTest (dtype = dtype . __name__ ):
301307 y = jax .jit (forward_laplacian (functools .partial (f , dtype = dtype )))(x )
302308 self .assertIsInstance (y , jax .Array )
303309
0 commit comments