From e7c019acbd5d9635e58e82798eaea876d5ea39dd Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Thu, 27 Jun 2024 18:42:24 +0000 Subject: [PATCH] Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Small formatting Fix merging issues --- MaxText/checkpointing.py | 12 ++++++ MaxText/max_utils.py | 80 +++++++++++++++++++++++++++++++++++++--- MaxText/train.py | 4 ++ xpk | 1 + 4 files changed, 92 insertions(+), 5 deletions(-) create mode 160000 xpk diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index f61b297531..e291ab4610 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -17,6 +17,7 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" from typing import Optional, Union +from absl import flags from etils import epath import orbax.checkpoint from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger @@ -81,6 +82,7 @@ def create_orbax_emergency_checkpoint_manager( persistent_save_interval_steps: int, ): """Returns an emergency checkpoint.""" + flags.FLAGS.experimental_orbax_use_distributed_process_id = True max_logging.log("Creating emergency checkpoint manager...") local_registry = type_handlers.create_type_handler_registry( @@ -220,6 +222,16 @@ def map_to_pspec(data): map_to_pspec, abstract_unboxed_pre_state, ) + + if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager): + return ( + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args), + ), + None, + ) + if dataset_type == "grain" and data_iterator is not None: return ( checkpoint_manager.restore( diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index ba7a508439..71fcb873e1 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -21,6 +21,7 @@ import time import socket import subprocess +from etils import epath import max_logging @@ -29,6 +30,7 @@ import jax.numpy as jnp from jax.experimental import mesh_utils import orbax.checkpoint as ocp +import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import json @@ -228,8 +230,10 @@ def maybe_initialize_jax_distributed_system(raw_keys): and not raw_keys["enable_single_controller"] ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") - jax.distributed.initialize() - ocp.multihost.utils.initialize_runtime_to_distributed_ids() + if not raw_keys['enable_emergency_checkpoint']: + jax.distributed.initialize() + else: + initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) max_logging.log("Jax distributed system initialized!") @@ -266,6 +270,69 @@ def initialize_jax_for_cpu(): ) +def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): + """Initialize JAX distributed runtime for TPUs when emergency checkpointing is used. + Currently, this only works for two scenarios: + 1) A fresh run where no ID files exist on any nodes, + 2) A "restore" run where the pods land on the same set of nodes as the previous run. + For Scenario #2 specifically, note the following cases where it will not work: + a) Some slices have failed, or + b) In a shared cluster, the pods land on a different set of nodes in the "restore run. + + Ongoing work outside of MaxText would get rid of these restrictions. + + TODO: Update this function to incorporate the ongoing work once available. + """ + ID_FILE = "id_file.txt" + + local_id_file = epath.Path(raw_keys["local_checkpoint_directory"]) / ID_FILE + if local_id_file.exists(): + max_logging.log("An ID file from a previous run exists, initializing JAX distributed runtime using the saved ID.") + process_id = int(local_id_file.read_text()) + coordinator_address_file = _get_coordinator_address_file(raw_keys) + if process_id == 0: + coordinator_address = _get_coordinator_address_for_emergency_checkpointing(raw_keys) + coordinator_address_file.write_text(coordinator_address) + else: + coordinator_address = _retrieve_coordinator_address(coordinator_address_file) + max_logging.log(f"Using the saved process_id of {process_id} and the 0'th process's address {coordinator_address}" + " to initialize JAX distributed runtime...") + jax.distributed.initialize(coordinator_address=coordinator_address, process_id=process_id) + else: + max_logging.log("No ID file from a previous run exists, initializing JAX distributed runtime without args.") + jax.distributed.initialize() + process_id = jax._src.distributed.global_state.process_id # pylint: disable=protected-access + local_id_file.write_text(str(process_id)) + + ocp.multihost.utils.initialize_runtime_to_distributed_ids() + + +def _get_run_name(raw_keys): + if raw_keys["run_name"] != "": + return raw_keys["run_name"] + return os.environ.get("JOBSET_NAME") + + +def _get_coordinator_address_file(raw_keys): + COORDINATOR_ADDRESS_FILE = "coordinator_address.txt" + return epath.Path(os.path.join(raw_keys["base_output_directory"], _get_run_name(raw_keys), COORDINATOR_ADDRESS_FILE)) + + +def _get_coordinator_address_for_emergency_checkpointing(raw_keys): + run_name = _get_run_name(raw_keys) + slice_id = os.environ.get('MEGASCALE_SLICE_ID') + worker_id = os.environ.get('TPU_WORKER_ID') + return f"{run_name}-slice-job-{slice_id}-{worker_id}.{run_name}:8476" + + +def _retrieve_coordinator_address(coordinator_address_file): + for _ in range(30): + if coordinator_address_file.exists(): + return coordinator_address_file.read_text() + time.sleep(1) + return "" + + def is_cpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a CPU backend.""" return raw_keys["hardware"] == "cpu" @@ -543,9 +610,12 @@ def setup_initial_state( ) if restored: - if "iter" in restored and restored["iter"] is not None: - data_iterator.local_iterator = restored["iter"] - state = restored["items"] + if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager): + state = restored + else: + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] else: init_state_partial = functools.partial( init_initial_state, model, tx, config, is_training diff --git a/MaxText/train.py b/MaxText/train.py index 6e003e6192..d565793b92 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -23,6 +23,7 @@ import datetime import os import sys +from etils import epath import functools from typing import Sequence @@ -351,6 +352,9 @@ def setup_mesh_and_model(config): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + if emergency_checkpoint_manager.should_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir)): + mesh = emergency_checkpoint_manager.consistent_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir), mesh) + # Model and Optimizer definition quant = quantizations.configure_quantization(config) model = Transformer(config, mesh, quant=quant) diff --git a/xpk b/xpk new file mode 160000 index 0000000000..096028844c --- /dev/null +++ b/xpk @@ -0,0 +1 @@ +Subproject commit 096028844c9c1ad02d53c63b0086aa71a4e02db0