Skip to content

Commit 9ddf63a

Browse files
lingvo-bot authored and copybara-github committed
Add support for decode_once with spmd model.
PiperOrigin-RevId: 415300883
1 parent fb11eb0 commit 9ddf63a

File tree

4 files changed

+204
-209
lines changed

4 files changed

+204
-209
lines changed

lingvo/jax/eval.py

Lines changed: 122 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ def get_shape_dtype(x):
216216
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
217217
logging.info('device_mesh: %s', device_mesh)
218218
with maps.mesh(device_mesh, model_p.mesh_axis_names):
219-
partitioned_train_state, partitioned_specs, _, eval_step, _, _, _ = (
219+
partitioned_train_state, partitioned_specs, _, eval_step, _ = (
220220
trainer_lib.partition_spmd_model(model_p, init_key, inputs_shape))
221221
partitioned_train_state = checkpoints.restore_checkpoint(
222222
partitioned_train_state,
@@ -316,8 +316,9 @@ def decode_once(
316316
inp.num_infeed_hosts = jax.process_count()
317317
inp.infeed_host_index = jax.process_index()
318318
if model_p.device_mesh is not None:
319-
del multi_host_checkpointing
320-
raise NotImplementedError('Decoding spmd model not yet supported.')
319+
decode_once_spmd_model(model_p, decoder_inputs, job_log_dir,
320+
checkpoint_type, restore_checkpoint_dir,
321+
restore_checkpoint_step, multi_host_checkpointing)
321322
else:
322323
decode_once_pmap_model(model_p, decoder_inputs, job_log_dir,
323324
checkpoint_type, restore_checkpoint_dir,
@@ -456,3 +457,121 @@ def decode_step(mdl_vars, prng_key, global_step, inputs):
456457
logging.info('Writing decoder output to %s with %d entries', output_file,
457458
len(decodes[split]))
458459
io_utils.WriteKeyValuePairs(output_file, decodes[split])
460+
461+
462+
def decode_once_spmd_model(
463+
model_p: InstantiableParams,
464+
input_p: Sequence[InstantiableParams],
465+
job_log_dir: Optional[str],
466+
checkpoint_type: checkpoints.CheckpointType,
467+
restore_checkpoint_dir: str,
468+
restore_checkpoint_step: Optional[int],
469+
multi_host_checkpointing: bool,
470+
) -> None:
471+
"""Runs the decoding once on the entire decoder datasets for SPMD model.
472+
473+
Args:
474+
model_p: Params for the spmd model.
475+
input_p: List of input params to be decoded.
476+
job_log_dir: Directory for the job logs.
477+
checkpoint_type: Type of model checkpointing method to use.
478+
restore_checkpoint_dir: The directory from which to restore checkpoint.
479+
restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
480+
try to restore from the latest checkpoint if any.
481+
multi_host_checkpointing: Whether to use multi-host checkpointing.
482+
"""
483+
# TODO(bf-jax): Retrieve the seeds from the model definition instead.
484+
prng_key = jax.random.PRNGKey(1234)
485+
prng_key, init_key = jax.random.split(prng_key)
486+
487+
if restore_checkpoint_dir and multi_host_checkpointing:
488+
# TODO(zhouwk): add sanity check on number of subdirs and number of
489+
# processes and fail early if unequal.
490+
restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
491+
f'{jax.process_index():03d}')
492+
493+
def get_shape_dtype(x):
494+
y = jax.ShapeDtypeStruct(x.shape, x.dtype)
495+
return y
496+
497+
sample_inputs = input_p[0].Instantiate().get_next()
498+
inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_inputs)
499+
500+
# TODO(b/198356509): This is a hack for now as we need to change some
501+
# annotations for mode='decode'. A future cl will move this logic
502+
# to a more generic model_p.update_sharding_params_v1(mode='decode').
503+
model_p.lm = model_p.lm.cls.set_sharding_params_v1(
504+
model_p.lm,
505+
replica_axis=model_p.lm.mesh_axis_names[0],
506+
data_axis=model_p.lm.mesh_axis_names[1],
507+
mdl_axis=model_p.lm.mesh_axis_names[2],
508+
device_ids_mesh=model_p.lm.device_mesh,
509+
mesh_axis_names=model_p.lm.mesh_axis_names,
510+
mode='decode')
511+
512+
mesh_shape = model_p.device_mesh.shape
513+
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
514+
logging.info('device_mesh: %s', device_mesh)
515+
if jax.process_index() == 0:
516+
# The instantiated model is only used for processing decode
517+
# outputs, which only happens on process 0.
518+
jax_model = model_p.Instantiate()
519+
with maps.mesh(device_mesh, model_p.mesh_axis_names):
520+
partitioned_train_state, partitioned_specs, decode_step_fn = (
521+
trainer_lib.partition_spmd_model_decode(model_p, init_key,
522+
inputs_shape))
523+
if restore_checkpoint_dir:
524+
partitioned_train_state = checkpoints.restore_checkpoint(
525+
partitioned_train_state,
526+
restore_checkpoint_dir,
527+
checkpoint_type=checkpoint_type,
528+
state_specs=partitioned_specs,
529+
step=restore_checkpoint_step)
530+
if multi_host_checkpointing:
531+
py_utils.sync_global_devices(
532+
f'checkpointer:restored:{restore_checkpoint_dir}')
533+
logging.info('partitioned_train_state: %s',
534+
jax.tree_map(lambda x: x.shape, partitioned_train_state))
535+
536+
# We do not fold in jax.process_index in contrast to the pmap version and
537+
# use a single global key instead to rely on pjit to split for different
538+
# replicas.
539+
logging.info('root prng_key: %s', prng_key)
540+
prng_key, decode_key = jax.random.split(prng_key)
541+
logging.info('eval prng_key: %s', decode_key)
542+
spmd_decode_step_fn = functools.partial(decode_step_fn,
543+
partitioned_train_state.mdl_vars,
544+
decode_key,
545+
partitioned_train_state.step)
546+
547+
num_steps = [
548+
-1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
549+
]
550+
inputs = [p.Instantiate() for p in input_p]
551+
decodes = [list() for _ in input_p]
552+
for split, num_split_steps in enumerate(num_steps):
553+
step_num = 0
554+
while num_split_steps < 0 or step_num < num_split_steps:
555+
step_num += 1
556+
try:
557+
batch = inputs[split].get_next()
558+
except tf.errors.OutOfRangeError:
559+
break
560+
out = spmd_decode_step_fn(batch)
561+
if jax.process_index() == 0:
562+
processed = jax_model.process_decode_out(inputs[split], out)
563+
decodes[split].extend(processed)
564+
565+
basedir = os.path.join(job_log_dir, 'decoder_out')
566+
dirnames = _get_dir_names(input_p)
567+
filename = _get_filename(partitioned_train_state.step)
568+
for s in dirnames:
569+
dir_path = os.path.join(basedir, s)
570+
if not tf.io.gfile.exists(dir_path):
571+
tf.io.gfile.makedirs(dir_path)
572+
filenames = [os.path.join(basedir, s, filename) for s in dirnames]
573+
if jax.process_index() == 0:
574+
for split, output_file in enumerate(filenames):
575+
logging.info('Writing decoder output to %s with %d entries', output_file,
576+
len(decodes[split]))
577+
io_utils.WriteKeyValuePairs(output_file, decodes[split])

lingvo/jax/mlperf/train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,7 @@ def get_shape_dtype(x):
397397
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
398398
logging.info('device_mesh: %s', device_mesh)
399399
with maps.mesh(device_mesh, model_p.mesh_axis_names):
400-
(partitioned_train_state, _, train_step, eval_step, _, _,
400+
(partitioned_train_state, _, train_step, eval_step,
401401
total_num_params) = trainer_lib.partition_spmd_model(
402402
model_p, init_key, inputs_shape)
403403

lingvo/jax/train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ def get_shape_dtype(x):
389389
device_mesh = mesh_utils.create_device_mesh(mesh_shape)
390390
logging.info('device_mesh: %s', device_mesh)
391391
with maps.mesh(device_mesh, model_p.mesh_axis_names):
392-
(partitioned_train_state, partitioned_specs, train_step, eval_step, _, _,
392+
(partitioned_train_state, partitioned_specs, train_step, eval_step,
393393
total_num_params) = trainer_lib.partition_spmd_model(
394394
model_p, init_key, inputs_shape)
395395

0 commit comments

Comments
 (0)