diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index f61b297531..e291ab4610 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -17,6 +17,7 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" from typing import Optional, Union +from absl import flags from etils import epath import orbax.checkpoint from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger @@ -81,6 +82,7 @@ def create_orbax_emergency_checkpoint_manager( persistent_save_interval_steps: int, ): """Returns an emergency checkpoint.""" + flags.FLAGS.experimental_orbax_use_distributed_process_id = True max_logging.log("Creating emergency checkpoint manager...") local_registry = type_handlers.create_type_handler_registry( @@ -220,6 +222,16 @@ def map_to_pspec(data): map_to_pspec, abstract_unboxed_pre_state, ) + + if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager): + return ( + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args), + ), + None, + ) + if dataset_type == "grain" and data_iterator is not None: return ( checkpoint_manager.restore( diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index ba7a508439..71fcb873e1 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -21,6 +21,7 @@ import time import socket import subprocess +from etils import epath import max_logging @@ -29,6 +30,7 @@ import jax.numpy as jnp from jax.experimental import mesh_utils import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import json @@ -228,8 +230,10 @@ def maybe_initialize_jax_distributed_system(raw_keys): and not raw_keys["enable_single_controller"] ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") - 
jax.distributed.initialize() - ocp.multihost.utils.initialize_runtime_to_distributed_ids() + if not raw_keys['enable_emergency_checkpoint']: + jax.distributed.initialize() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) max_logging.log("Jax distributed system initialized!") @@ -266,6 +270,69 @@ def initialize_jax_for_cpu(): ) +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + Currently, this only works for two scenarios: + 1) A fresh run where no ID files exist on any nodes, + 2) A "restore" run where the pods land on the same set of nodes as the previous run. + For Scenario #2 specifically, note the following cases where it will not work: + a) Some slices have failed, or + b) In a shared cluster, the pods land on a different set of nodes in the "restore" run. + + Ongoing work outside of MaxText would get rid of these restrictions. + + TODO: Update this function to incorporate the ongoing work once available. 
+ """ + ID_FILE = "id_file.txt" + + local_id_file = epath.Path(raw_keys["local_checkpoint_directory"]) / ID_FILE + if local_id_file.exists(): + max_logging.log("An ID file from a previous run exists, initializing JAX distributed runtime using the saved ID.") + process_id = int(local_id_file.read_text()) + coordinator_address_file = _get_coordinator_address_file(raw_keys) + if process_id == 0: + coordinator_address = _get_coordinator_address_for_emergency_checkpointing(raw_keys) + coordinator_address_file.write_text(coordinator_address) + else: + coordinator_address = _retrieve_coordinator_address(coordinator_address_file) + max_logging.log(f"Using the saved process_id of {process_id} and the 0'th process's address {coordinator_address}" + " to initialize JAX distributed runtime...") + jax.distributed.initialize(coordinator_address=coordinator_address, process_id=process_id) + else: + max_logging.log("No ID file from a previous run exists, initializing JAX distributed runtime without args.") + jax.distributed.initialize() + process_id = jax._src.distributed.global_state.process_id # pylint: disable=protected-access + local_id_file.write_text(str(process_id)) + + ocp.multihost.utils.initialize_runtime_to_distributed_ids() + + +def _get_run_name(raw_keys): + if raw_keys["run_name"] != "": + return raw_keys["run_name"] + return os.environ.get("JOBSET_NAME") + + +def _get_coordinator_address_file(raw_keys): + COORDINATOR_ADDRESS_FILE = "coordinator_address.txt" + return epath.Path(os.path.join(raw_keys["base_output_directory"], _get_run_name(raw_keys), COORDINATOR_ADDRESS_FILE)) + + +def _get_coordinator_address_for_emergency_checkpointing(raw_keys): + run_name = _get_run_name(raw_keys) + slice_id = os.environ.get('MEGASCALE_SLICE_ID') + worker_id = os.environ.get('TPU_WORKER_ID') + return f"{run_name}-slice-job-{slice_id}-{worker_id}.{run_name}:8476" + + +def _retrieve_coordinator_address(coordinator_address_file): + for _ in range(30): + if 
coordinator_address_file.exists(): + return coordinator_address_file.read_text() + time.sleep(1) + return "" + + def is_cpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a CPU backend.""" return raw_keys["hardware"] == "cpu" @@ -543,9 +610,12 @@ def setup_initial_state( ) if restored: - if "iter" in restored and restored["iter"] is not None: - data_iterator.local_iterator = restored["iter"] - state = restored["items"] + if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager): + state = restored + else: + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] else: init_state_partial = functools.partial( init_initial_state, model, tx, config, is_training diff --git a/MaxText/train.py b/MaxText/train.py index 6e003e6192..d565793b92 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -23,6 +23,7 @@ import datetime import os import sys +from etils import epath import functools from typing import Sequence @@ -351,6 +352,9 @@ def setup_mesh_and_model(config): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + if emergency_checkpoint_manager.should_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir)): + mesh = emergency_checkpoint_manager.consistent_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir), mesh) + # Model and Optimizer definition quant = quantizations.configure_quantization(config) model = Transformer(config, mesh, quant=quant) diff --git a/xpk b/xpk new file mode 160000 index 0000000000..096028844c --- /dev/null +++ b/xpk @@ -0,0 +1 @@ +Subproject commit 096028844c9c1ad02d53c63b0086aa71a4e02db0