Skip to content

Commit 2db2059

Browse files
lingvo-botcopybara-github
authored andcommitted
Make sure pmap decode one step works with multi-processes by adding all_gather to get decoded outputs from all replicas.
PiperOrigin-RevId: 414907978
1 parent 8b6a834 commit 2db2059

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

lingvo/jax/eval.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from absl import logging
2626
import jax
27-
from jax import numpy as jnp
2827
from jax.experimental import maps
2928
from lingvo.jax import base_layer
3029
from lingvo.jax import base_model_params
@@ -396,20 +395,35 @@ def decode_once_pmap_model(
396395
# different prng_key.
397396
prng_key = jax.random.fold_in(prng_key, jax.process_index())
398397
logging.info('root prng_key: %s', prng_key)
399-
400-
def decode_step(mdl_vars, prng_key, global_step, inputs):
401-
return trainer_lib.decode_step(jax_model, mdl_vars, prng_key, global_step,
402-
inputs, model_p.fprop_dtype)
403-
404-
num_devices = jax.local_device_count()
405398
prng_key, eval_key = jax.random.split(prng_key)
406-
eval_prng_seed = jax.random.split(eval_key, num=num_devices)
407-
logging.info('eval prng_seed: %s', eval_prng_seed)
399+
prng_seed = jax.random.split(eval_key, num=jax.local_device_count())
400+
logging.info('decoder prng_seed: %s', prng_seed)
408401

409-
p_decode_step = jax.pmap(decode_step, axis_name='batch')
410-
decode_step = functools.partial(p_decode_step,
411-
replicated_model_states.mdl_vars,
412-
eval_prng_seed, replicated_model_states.step)
402+
def decode_step(mdl_vars, prng_key, global_step, inputs):
403+
out = trainer_lib.decode_step(jax_model, mdl_vars, prng_key, global_step,
404+
inputs, model_p.fprop_dtype)
405+
out = jax.lax.all_gather(out, axis_name='batch', tiled=True)
406+
return out
407+
408+
# As an example, suppose the output leaf from trainer_lib.decoder_step()
409+
# for each core has shape: [per_core_batch_size, decoding_length].
410+
# In the all_gather we set tiled=True, so the output chunks are all
411+
# concatenated into the existing batch axis, so we get shape
412+
# [num_cores x per_core_batch_size, decoding_length].
413+
# In the pmap call we set out_axes=None to not have to manually unreplicate,
414+
# so the output of pmap_decode_step() will have the same shape.
415+
#
416+
# Example code snippet showing this:
417+
# # shape (8, 3, 2)
418+
# x = jnp.tile(jnp.arange(8)[:, None, None],[1, 3, 2])
419+
# # shape (24, 2)
420+
# z = jax.pmap(
421+
# lambda y: jax.lax.all_gather(y+1, axis_name='i', tiled=True),
422+
# axis_name='i', out_axes=None)(x)
423+
pmap_decode_step = jax.pmap(decode_step, axis_name='batch', out_axes=None)
424+
decode_step_func = functools.partial(pmap_decode_step,
425+
replicated_model_states.mdl_vars,
426+
prng_seed, replicated_model_states.step)
413427

414428
num_steps = [
415429
-1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
@@ -424,11 +438,8 @@ def decode_step(mdl_vars, prng_key, global_step, inputs):
424438
except tf.errors.OutOfRangeError:
425439
break
426440
batch = tf.nest.map_structure(py_utils.reshard, batch)
427-
out = decode_step(batch)
441+
out = decode_step_func(batch)
428442
if jax.process_index() == 0:
429-
# TODO(zhouwk): test with multi-process. Do we need all_gather?
430-
out = jax.tree_map(lambda x: jnp.reshape(x, [-1] + list(x.shape[2:])),
431-
out)
432443
processed = jax_model.process_decode_out(inputs[split], out)
433444
decodes[split].extend(processed)
434445

0 commit comments

Comments
 (0)