2424
2525from absl import logging
2626import jax
27- from jax import numpy as jnp
2827from jax .experimental import maps
2928from lingvo .jax import base_layer
3029from lingvo .jax import base_model_params
@@ -396,20 +395,35 @@ def decode_once_pmap_model(
396395 # different prng_key.
397396 prng_key = jax .random .fold_in (prng_key , jax .process_index ())
398397 logging .info ('root prng_key: %s' , prng_key )
399-
400- def decode_step (mdl_vars , prng_key , global_step , inputs ):
401- return trainer_lib .decode_step (jax_model , mdl_vars , prng_key , global_step ,
402- inputs , model_p .fprop_dtype )
403-
404- num_devices = jax .local_device_count ()
405398 prng_key , eval_key = jax .random .split (prng_key )
406- eval_prng_seed = jax .random .split (eval_key , num = num_devices )
407- logging .info ('eval prng_seed: %s' , eval_prng_seed )
399+ prng_seed = jax .random .split (eval_key , num = jax . local_device_count () )
400+ logging .info ('decoder prng_seed: %s' , prng_seed )
408401
409- p_decode_step = jax .pmap (decode_step , axis_name = 'batch' )
410- decode_step = functools .partial (p_decode_step ,
411- replicated_model_states .mdl_vars ,
412- eval_prng_seed , replicated_model_states .step )
402+ def decode_step (mdl_vars , prng_key , global_step , inputs ):
403+ out = trainer_lib .decode_step (jax_model , mdl_vars , prng_key , global_step ,
404+ inputs , model_p .fprop_dtype )
405+ out = jax .lax .all_gather (out , axis_name = 'batch' , tiled = True )
406+ return out
407+
408+ # As an example, suppose the output leaf from trainer_lib.decode_step()
409+ # for each core has shape: [per_core_batch_size, decoding_length].
410+ # In the all_gather we set tiled=True, so the per-core output chunks are
411+ # concatenated along the existing batch axis, giving shape
412+ # [num_cores x per_core_batch_size, decoding_length].
413+ # In the pmap call we set out_axes=None to not have to manually unreplicate,
414+ # so the output of pmap_decode_step() will have the same shape.
415+ #
416+ # Example code snippet showing this:
417+ # # shape (8, 3, 2)
418+ # x = jnp.tile(jnp.arange(8)[:, None, None],[1, 3, 2])
419+ # # shape (24, 2)
420+ # z = jax.pmap(
421+ # lambda y: jax.lax.all_gather(y+1, axis_name='i', tiled=True),
422+ # axis_name='i', out_axes=None)(x)
423+ pmap_decode_step = jax .pmap (decode_step , axis_name = 'batch' , out_axes = None )
424+ decode_step_func = functools .partial (pmap_decode_step ,
425+ replicated_model_states .mdl_vars ,
426+ prng_seed , replicated_model_states .step )
413427
414428 num_steps = [
415429 - 1 if p .reset_for_eval else p .eval_loop_num_batches for p in input_p
@@ -424,11 +438,8 @@ def decode_step(mdl_vars, prng_key, global_step, inputs):
424438 except tf .errors .OutOfRangeError :
425439 break
426440 batch = tf .nest .map_structure (py_utils .reshard , batch )
427- out = decode_step (batch )
441+ out = decode_step_func (batch )
428442 if jax .process_index () == 0 :
429- # TODO(zhouwk): test with multi-process. Do we need all_gather?
430- out = jax .tree_map (lambda x : jnp .reshape (x , [- 1 ] + list (x .shape [2 :])),
431- out )
432443 processed = jax_model .process_decode_out (inputs [split ], out )
433444 decodes [split ].extend (processed )
434445
0 commit comments