We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 0c2cd81 commit a2f1e9aCopy full SHA for a2f1e9a
lingvo/jax/eval.py
@@ -494,7 +494,12 @@ def decode_once_spmd_model(
494
f'{jax.process_index():03d}')
495
496
def get_shape_dtype(x):
497
- y = jax.ShapeDtypeStruct(x.shape, x.dtype)
+ # The sample input batch we are getting shape from is only from
498
+ # the current process. Manually scale this to the global batch size
499
+ # by assuming all the hosts infeed the same data.
500
+ assert len(x.shape) >= 1
501
+ x_shape = (x.shape[0] * jax.process_count(),) + x.shape[1:]
502
+ y = jax.ShapeDtypeStruct(x_shape, x.dtype)
503
return y
504
505
sample_inputs = input_p[0].Instantiate().get_next()
0 commit comments