@@ -216,7 +216,7 @@ def get_shape_dtype(x):
216216 device_mesh = mesh_utils .create_device_mesh (mesh_shape )
217217 logging .info ('device_mesh: %s' , device_mesh )
218218 with maps .mesh (device_mesh , model_p .mesh_axis_names ):
219- partitioned_train_state , partitioned_specs , _ , eval_step , _ , _ , _ = (
219+ partitioned_train_state , partitioned_specs , _ , eval_step , _ = (
220220 trainer_lib .partition_spmd_model (model_p , init_key , inputs_shape ))
221221 partitioned_train_state = checkpoints .restore_checkpoint (
222222 partitioned_train_state ,
@@ -316,8 +316,9 @@ def decode_once(
316316 inp .num_infeed_hosts = jax .process_count ()
317317 inp .infeed_host_index = jax .process_index ()
318318 if model_p .device_mesh is not None :
319- del multi_host_checkpointing
320- raise NotImplementedError ('Decoding spmd model not yet supported.' )
319+ decode_once_spmd_model (model_p , decoder_inputs , job_log_dir ,
320+ checkpoint_type , restore_checkpoint_dir ,
321+ restore_checkpoint_step , multi_host_checkpointing )
321322 else :
322323 decode_once_pmap_model (model_p , decoder_inputs , job_log_dir ,
323324 checkpoint_type , restore_checkpoint_dir ,
@@ -456,3 +457,121 @@ def decode_step(mdl_vars, prng_key, global_step, inputs):
456457 logging .info ('Writing decoder output to %s with %d entries' , output_file ,
457458 len (decodes [split ]))
458459 io_utils .WriteKeyValuePairs (output_file , decodes [split ])
460+
461+
def decode_once_spmd_model(
    model_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: checkpoints.CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
    multi_host_checkpointing: bool,
) -> None:
  """Runs the decoding once on the entire decoder datasets for SPMD model.

  The model is partitioned over a device mesh built from
  `model_p.device_mesh.shape`, optionally restored from a checkpoint, and
  then run over every input in `input_p`. Process 0 post-processes each
  decoded batch with the instantiated model's `process_decode_out` and writes
  the accumulated results with `io_utils.WriteKeyValuePairs`.

  Note that `model_p.lm` is replaced in place with decode-mode sharding
  params, so the caller's `model_p` is modified by this call.

  Args:
    model_p: Params for the spmd model.
    input_p: List of input params to be decoded. The first entry is also
      instantiated once to obtain the input shapes/dtypes for partitioning.
    job_log_dir: Directory for the job logs. It is joined with 'decoder_out'
      unconditionally, so it must not be None despite the Optional annotation.
    checkpoint_type: Type of model checkpointing method to use.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
      If falsy, no checkpoint is restored.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    multi_host_checkpointing: Whether to use multi-host checkpointing. When
      set, each process restores from its own zero-padded, 3-digit
      subdirectory of `restore_checkpoint_dir`.
  """
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  if restore_checkpoint_dir and multi_host_checkpointing:
    # TODO(zhouwk): add sanity check on number of subdirs and number of
    # processes and fail early if unequal.
    # The subdirectory name must be exactly the zero-padded process index
    # (e.g. '000'); any stray whitespace would point at a different directory.
    restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                          f'{jax.process_index():03d}')

  def get_shape_dtype(x):
    # Only shape and dtype are needed to partition the model, not the values.
    return jax.ShapeDtypeStruct(x.shape, x.dtype)

  sample_inputs = input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_inputs)

  # TODO(b/198356509): This is a hack for now as we need to change some
  # annotations for mode='decode'. A future cl will move this logic
  # to a more generic model_p.update_sharding_params_v1(mode='decode').
  model_p.lm = model_p.lm.cls.set_sharding_params_v1(
      model_p.lm,
      replica_axis=model_p.lm.mesh_axis_names[0],
      data_axis=model_p.lm.mesh_axis_names[1],
      mdl_axis=model_p.lm.mesh_axis_names[2],
      device_ids_mesh=model_p.lm.device_mesh,
      mesh_axis_names=model_p.lm.mesh_axis_names,
      mode='decode')

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  if jax.process_index() == 0:
    # The instantiated model is only used for processing decode
    # outputs, which only happens on process 0.
    jax_model = model_p.Instantiate()
  with maps.mesh(device_mesh, model_p.mesh_axis_names):
    partitioned_train_state, partitioned_specs, decode_step_fn = (
        trainer_lib.partition_spmd_model_decode(model_p, init_key,
                                                inputs_shape))
    if restore_checkpoint_dir:
      partitioned_train_state = checkpoints.restore_checkpoint(
          partitioned_train_state,
          restore_checkpoint_dir,
          checkpoint_type=checkpoint_type,
          state_specs=partitioned_specs,
          step=restore_checkpoint_step)
      if multi_host_checkpointing:
        py_utils.sync_global_devices(
            f'checkpointer:restored:{restore_checkpoint_dir}')
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))

    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, decode_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', decode_key)
    # Bind everything except the batch so the loop below only passes inputs.
    spmd_decode_step_fn = functools.partial(decode_step_fn,
                                            partitioned_train_state.mdl_vars,
                                            decode_key,
                                            partitioned_train_state.step)

    # A negative step count means "run until the input raises OutOfRange".
    num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
    ]
    inputs = [p.Instantiate() for p in input_p]
    decodes = [[] for _ in input_p]
    for split, num_split_steps in enumerate(num_steps):
      step_num = 0
      while num_split_steps < 0 or step_num < num_split_steps:
        step_num += 1
        try:
          batch = inputs[split].get_next()
        except tf.errors.OutOfRangeError:
          break
        # Every process runs the step; only process 0 keeps the outputs.
        out = spmd_decode_step_fn(batch)
        if jax.process_index() == 0:
          processed = jax_model.process_decode_out(inputs[split], out)
          decodes[split].extend(processed)

  basedir = os.path.join(job_log_dir, 'decoder_out')
  dirnames = _get_dir_names(input_p)
  filename = _get_filename(partitioned_train_state.step)
  for s in dirnames:
    dir_path = os.path.join(basedir, s)
    if not tf.io.gfile.exists(dir_path):
      tf.io.gfile.makedirs(dir_path)
  filenames = [os.path.join(basedir, s, filename) for s in dirnames]
  if jax.process_index() == 0:
    for split, output_file in enumerate(filenames):
      logging.info('Writing decoder output to %s with %d entries', output_file,
                   len(decodes[split]))
      io_utils.WriteKeyValuePairs(output_file, decodes[split])
0 commit comments