Skip to content

Commit a2f1e9a

Browse files
lingvo-botcopybara-github
authored andcommitted
Fix spmd decoding: manually fix global batch size by a factor of process count.
PiperOrigin-RevId: 415626396
1 parent 0c2cd81 commit a2f1e9a

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

lingvo/jax/eval.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,12 @@ def decode_once_spmd_model(
494494
f'{jax.process_index():03d}')
495495

496496
def get_shape_dtype(x):
497-
y = jax.ShapeDtypeStruct(x.shape, x.dtype)
497+
# The sample input batch we are getting shape from is only from
498+
# the current process. Manually scale this to the global batch size
499+
# by assuming all the hosts infeed the same data.
500+
assert len(x.shape) >= 1
501+
x_shape = (x.shape[0] * jax.process_count(),) + x.shape[1:]
502+
y = jax.ShapeDtypeStruct(x_shape, x.dtype)
498503
return y
499504

500505
sample_inputs = input_p[0].Instantiate().get_next()

0 commit comments

Comments
 (0)