Skip to content

Commit a63c08b

Browse files
Copilotor4k2l
andcommitted
Add batch_stats verification in test_training_mode
Co-authored-by: or4k2l <219930442+or4k2l@users.noreply.github.com>
1 parent c6e7398 commit a63c08b

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def test_training_mode(self, model, test_input):
5454
# Extract logits and updated batch_stats
5555
logits, updated_vars = output
5656

57+
# Verify batch_stats is present (model uses BatchNorm)
58+
assert 'batch_stats' in updated_vars
59+
5760
# Should produce valid output
5861
assert logits.shape == (4, 10)
5962
assert not jnp.any(jnp.isnan(logits))

0 commit comments

Comments
 (0)