From 4307ed223b6a2da9f337dc6832480e8dee99a159 Mon Sep 17 00:00:00 2001 From: yangyuwei Date: Tue, 27 Feb 2024 16:59:02 -0800 Subject: [PATCH 01/10] Make small changes for running it via XPK. --- docker_build_dependency_image.sh | 2 +- gke/gpu/start_training.sh | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index a1b0cfeaf6..7ffdee048a 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -22,7 +22,7 @@ # Enable "exit immediately if any command fails" option set -e -export LOCAL_IMAGE_NAME=maxtext_base_image +export LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate." diff --git a/gke/gpu/start_training.sh b/gke/gpu/start_training.sh index a8dca2db2f..8efbbd94e7 100644 --- a/gke/gpu/start_training.sh +++ b/gke/gpu/start_training.sh @@ -115,9 +115,6 @@ resolve_coordinator_ip() { fi } -# HLO dump -export XLA_FLAGS="--xla_dump_to=/tmp/xladump" - # Resolving coordinator IP set +e resolve_coordinator_ip @@ -126,10 +123,10 @@ set -e PIDS=() for ((LOCAL_DEVICE_ID=0; LOCAL_DEVICE_ID <= $((GPUS_PER_NODE - 1)); LOCAL_DEVICE_ID++)); do PROCESS_ID=$(($GPUS_PER_NODE*$NODE_RANK + $LOCAL_DEVICE_ID)) - LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=${RUN_NAME}_$(date +%Y-%m-%d-%H-%M) & + LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=${RUN_NAME}_$(date +%Y-%m-%d-%H-%M) ${ARGS} & PID=$! PIDS+=($PID) echo "Launched MaxText/train.py for local_device_id: $LOCAL_DEVICE_ID process_id: $PROCESS_ID and PID $PID" done -wait_all_success_or_exit "${PIDS[@]}" \ No newline at end of file +wait_all_success_or_exit "${PIDS[@]}" From 6b105eb24ed3b607ff6409ba065cd448db091192 Mon Sep 17 00:00:00 2001 From: yangyuwei Date: Tue, 27 Feb 2024 17:05:07 -0800 Subject: [PATCH 02/10] Revert changes to docker_build_dependency_image.sh. --- docker_build_dependency_image.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 7ffdee048a..a1b0cfeaf6 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -22,7 +22,7 @@ # Enable "exit immediately if any command fails" option set -e -export LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME +export LOCAL_IMAGE_NAME=maxtext_base_image echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate." From 7888f6acd4c3d17cf0e35310445ce300a0276646 Mon Sep 17 00:00:00 2001 From: yangyuwei Date: Wed, 28 Feb 2024 16:59:11 -0800 Subject: [PATCH 03/10] Refactor the code for MaxText training on H100 GPUs via XPK. --- gke/gpu/maxtext_chart/Chart.yaml | 24 --- gke/gpu/maxtext_chart/templates/maxtext.yaml | 144 ------------------ gke/gpu/maxtext_chart/values.yaml | 11 -- ...rt_training.sh => gpu_multi_process_run.sh | 4 +- 4 files changed, 2 insertions(+), 181 deletions(-) delete mode 100644 gke/gpu/maxtext_chart/Chart.yaml delete mode 100644 gke/gpu/maxtext_chart/templates/maxtext.yaml delete mode 100644 gke/gpu/maxtext_chart/values.yaml rename gke/gpu/start_training.sh => gpu_multi_process_run.sh (95%) diff --git a/gke/gpu/maxtext_chart/Chart.yaml b/gke/gpu/maxtext_chart/Chart.yaml deleted file mode 100644 index bbe13883f3..0000000000 --- a/gke/gpu/maxtext_chart/Chart.yaml +++ /dev/null @@ -1,24 +0,0 @@ -apiVersion: v2 -name: maxtext_chart -description: A Helm chart for Maxtext GPU workload - -# A chart can be either an 'application' or a 'library' chart. -# -# Application charts are a collection of templates that can be packaged into versioned archives -# to be deployed. -# -# Library charts provide useful utilities or functions for the chart developer. They're included as -# a dependency of application charts to inject those utilities and functions into the rendering -# pipeline. Library charts do not define any templates and therefore cannot be deployed. -type: application - -# This is the chart version. This version number should be incremented each time you make changes -# to the chart and its templates, including the app version. -# Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 - -# This is the version number of the application being deployed. This version number should be -# incremented each time you make changes to the application. Versions are not expected to -# follow Semantic Versioning. They should reflect the version the application is using. -# It is recommended to use it with quotes. -appVersion: "1.16.0" diff --git a/gke/gpu/maxtext_chart/templates/maxtext.yaml b/gke/gpu/maxtext_chart/templates/maxtext.yaml deleted file mode 100644 index 5ab88f38cb..0000000000 --- a/gke/gpu/maxtext_chart/templates/maxtext.yaml +++ /dev/null @@ -1,144 +0,0 @@ -{{- $requiredVar := .Values.cluster.nNodes | required ".Values.cluster.nNodes is required" -}} -{{- $requiredVar := .Values.cluster.nodePool | required ".Values.cluster.nodePool is required" -}} -{{- $requiredVar := .Values.workload.image | required ".Values.image is required" -}} -apiVersion: v1 -kind: Service -metadata: - name: "maxtext-leader-{{$.Release.Name}}" -spec: - selector: - name: "maxtext-leader-{{$.Release.Name}}" - clusterIP: None - ports: - - name: maxtext-leader - port: 6002 ---- -{{$node_count := .Values.cluster.nNodes | int}} -# This needs to be updated to allow uneven distribution of nodes to SBs -{{- $root := . -}} -{{range $node_index, $element := until $node_count}} -apiVersion: v1 -kind: Pod -metadata: - name: maxtext-{{$.Release.Name}}-pod{{$node_index}} - {{if eq $node_index 0}} - labels: - name: maxtext-leader-{{$.Release.Name}} - {{end}} -spec: - hostNetwork: true - dnsPolicy: ClusterFirstWithHostNet - hostname: maxtext-pod{{$node_index}} - subdomain: maxtext-{{$.Release.Name}} - serviceAccountName: "default" - restartPolicy: Never - affinity: - nodeAffinity: - requiredDuringSchedulingIgnoredDuringExecution: - nodeSelectorTerms: - - matchExpressions: - - key: cloud.google.com/gke-accelerator - operator: Exists - - key: cloud.google.com/gke-nodepool - operator: In - values: [{{$.Values.cluster.nodePool}}] - tolerations: - - operator: "Exists" - key: nvidia.com/gpu - volumes: - - name: nvidia-install-dir-host - hostPath: - path: /home/kubernetes/bin/nvidia/lib64 - - name: tcpd-socket - hostPath: - path: /run/tcpx - - name: shared-memory - emptyDir: - medium: "Memory" - sizeLimit: 200Gi - - name: workload-terminated-volume - emptyDir: {} - - name: tcpx-nccl-plugin-volume - emptyDir: {} - {{if eq $root.Values.network.useTcpx "yes"}} - initContainers: - - name: tcpx-nccl-plugin-installer - image: {{$root.Values.network.ncclPlugin}} - imagePullPolicy: Always - volumeMounts: - - name: tcpx-nccl-plugin-volume - mountPath: /var/lib/tcpx - resources: - requests: - cpu: 150m - command: - - /bin/sh - - -c - - | - /scripts/container_entry.sh install --install-nccl - {{end}} - containers: - {{if eq $root.Values.network.useTcpx "yes"}} - - name: tcpd-daemon - image: {{$root.Values.network.rxdmContainer}} - imagePullPolicy: Always - command: - - "bash" - - "-c" - - | - /tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm --gpu_shmem_type fd --setup_param "--verbose 128 2 0" & - while [ ! -e "/usr/share/maxtext/workload_terminated" ]; do sleep 10; echo "sleeping"; done - securityContext: - privileged: true - volumeMounts: - - name: nvidia-install-dir-host - mountPath: /usr/local/nvidia/lib64 - - name: tcpd-socket - mountPath: /tmp - - name: workload-terminated-volume - mountPath: /usr/share/maxtext - env: - - name: LD_LIBRARY_PATH - value: /usr/local/nvidia/lib64 - {{end}} - - name: maxtext - image: {{$root.Values.workload.image}} - imagePullPolicy: Always - securityContext: - privileged: true - env: - - name: NNODES - value: "{{$node_count}}" - - name: NODE_RANK - value: "{{ $node_index }}" - - name: USE_GPUDIRECT_TCPX - value: "{{$root.Values.network.useTcpx}}" - - name: GPUS_PER_NODE - value: "8" - - name: JAX_COORDINATOR_ADDRESS - value: "maxtext-leader-{{$.Release.Name}}" - - name: JAX_COORDINATOR_PORT - value: "{{$root.Values.workload.port}}" - - name: RUN_NAME - value: "{{$root.Values.workload.runName}}" - - name: LD_LIBRARY_PATH - value: /usr/local/nvidia/lib64 - volumeMounts: - - name: nvidia-install-dir-host - mountPath: /usr/local/nvidia/lib64 - - name: tcpx-nccl-plugin-volume - mountPath: /usr/local/tcpx - - name: tcpd-socket - mountPath: /tmp - - name: shared-memory - mountPath: /dev/shm - resources: - limits: - nvidia.com/gpu: !!int 8 - command: - - /bin/sh - - -c - - | - cd /deps && bash gke/gpu/start_training.sh ---- -{{end}} \ No newline at end of file diff --git a/gke/gpu/maxtext_chart/values.yaml b/gke/gpu/maxtext_chart/values.yaml deleted file mode 100644 index beaa962729..0000000000 --- a/gke/gpu/maxtext_chart/values.yaml +++ /dev/null @@ -1,11 +0,0 @@ -cluster: - nNodes: 2 # Configure the number of nodes - nodePool: "" # Configure NodePool Information -network: - useTcpx: "yes" - ncclPlugin: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/nccl-plugin-gpudirecttcpx-dev:v3.1.6_2023_10_06 - rxdmContainer: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.9 -workload: - image: "" # Configure Image Name - port: 6002 - runName: "" # Configure Run Name diff --git a/gke/gpu/start_training.sh b/gpu_multi_process_run.sh similarity index 95% rename from gke/gpu/start_training.sh rename to gpu_multi_process_run.sh index 8efbbd94e7..d968dd80a8 100644 --- a/gke/gpu/start_training.sh +++ b/gpu_multi_process_run.sh @@ -8,7 +8,7 @@ set -o pipefail : "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}" : "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}" : "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}" -: "${RUN_NAME:?Must set RUN_NAME}" +: "${COMMAND:?Must set COMMAND}" export GPUS_PER_NODE=$GPUS_PER_NODE @@ -123,7 +123,7 @@ set -e PIDS=() for ((LOCAL_DEVICE_ID=0; LOCAL_DEVICE_ID <= $((GPUS_PER_NODE - 1)); LOCAL_DEVICE_ID++)); do PROCESS_ID=$(($GPUS_PER_NODE*$NODE_RANK + $LOCAL_DEVICE_ID)) - LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=${RUN_NAME}_$(date +%Y-%m-%d-%H-%M) ${ARGS} & + LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID ${COMMAND} & PID=$! PIDS+=($PID) echo "Launched MaxText/train.py for local_device_id: $LOCAL_DEVICE_ID process_id: $PROCESS_ID and PID $PID" From fc0f9cf85913594bfcd22d843a0c495b5f0163ca Mon Sep 17 00:00:00 2001 From: yangyuwei Date: Wed, 13 Mar 2024 11:20:25 -0700 Subject: [PATCH 04/10] Fix the issue of creating 8 processes per GPU. --- MaxText/max_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 927aff4bd7..aef441df44 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -196,10 +196,8 @@ def initialize_jax_for_gpu(): coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", - num_processes=int(os.getenv("JAX_NUM_PROCESSES")), - process_id=int(os.getenv("PROCESS_ID")), - local_device_ids=int(os.getenv("LOCAL_DEVICE_ID")), - ) + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK"))) max_logging.log(f"JAX global devices: {jax.devices()}") def initialize_jax_for_cpu(): From 1046886607644b6191ed50a0fb8dce0b45c45693 Mon Sep 17 00:00:00 2001 From: yangyuwei Date: Wed, 13 Mar 2024 11:22:08 -0700 Subject: [PATCH 05/10] Fix the issue of creating 8 processes per GPU for gpu_multi_process_run.sh --- gpu_multi_process_run.sh | 43 ++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh index d968dd80a8..f496be3a3e 100644 --- a/gpu_multi_process_run.sh +++ b/gpu_multi_process_run.sh @@ -14,10 +14,9 @@ set -o pipefail export GPUS_PER_NODE=$GPUS_PER_NODE export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS -export JAX_NUM_PROCESSES=$((NNODES * GPUS_PER_NODE)) set_nccl_gpudirect_tcpx_specific_configuration() { - if [[ "$USE_GPUDIRECT_TCPX" == "yes" ]]; then + if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then echo "Using GPUDirect-TCPX" export NCCL_CROSS_NIC=0 export NCCL_ALGO=Ring @@ -44,8 +43,33 @@ set_nccl_gpudirect_tcpx_specific_configuration() { export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4 export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0 export NCCL_NVLS_ENABLE=0 + elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then + echo "Using GPUDirect-TCPFasTrak" + export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION + export NCCL_DEBUG=INFO + export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0 + export LD_LIBRARY_PATH="/usr/local/fastrak/lib64:${LD_LIBRARY_PATH}" + export NCCL_FASTRAK_CTRL_DEV=eth0 + export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8 + export NCCL_SOCKET_IFNAME=eth0 + export NCCL_CROSS_NIC=0 + export NCCL_ALGO=Ring + export NCCL_PROTO=Simple + export NCCL_MAX_NCHANNELS=16 + export NCCL_MIN_NCHANNELS=16 + export NCCL_SOCKET_NTHREADS=4 + export NCCL_DYNAMIC_CHUNK_SIZE=524288 + export NCCL_DYNAMIC_CHUNK_SIZE=524288 + export NCCL_P2P_NET_CHUNKSIZE=524288 + export NCCL_P2P_PCI_CHUNKSIZE=524288 + export NCCL_P2P_NVL_CHUNKSIZE=1048576 + export NCCL_FASTRAK_NUM_FLOWS=8 + export NCCL_FASTRAK_FLOWS_PER_GROUP=2 + export NCCL_BUFFSIZE=4194304 + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + export NCCL_NET_GDR_LEVEL=PIX else - echo "NOT using TCPX" + echo "NOT using GPUDirect" fi } @@ -89,7 +113,7 @@ non_blocking_wait() { resolve_coordinator_ip() { local lookup_attempt=1 - local max_coordinator_lookups=10 + local max_coordinator_lookups=500 local coordinator_found=false local coordinator_ip_address="" @@ -121,12 +145,9 @@ resolve_coordinator_ip set -e PIDS=() -for ((LOCAL_DEVICE_ID=0; LOCAL_DEVICE_ID <= $((GPUS_PER_NODE - 1)); LOCAL_DEVICE_ID++)); do - PROCESS_ID=$(($GPUS_PER_NODE*$NODE_RANK + $LOCAL_DEVICE_ID)) - LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID ${COMMAND} & - PID=$! - PIDS+=($PID) - echo "Launched MaxText/train.py for local_device_id: $LOCAL_DEVICE_ID process_id: $PROCESS_ID and PID $PID" -done +${COMMAND} & +PID=$! +PIDS+=($PID) wait_all_success_or_exit "${PIDS[@]}" + From b2590421c347a64c1bb5a0bb674502ee4e303f5d Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Wed, 13 Mar 2024 21:52:55 +0000 Subject: [PATCH 06/10] Constraint dependency versions Workaround of https://github.com/google/maxtext/issues/516 Also pin other dependencies for mostly reproducible container build --- .dockerignore | 1 + constraints.txt | 133 ++++++++++++++++++++++++++++ maxtext_gpu_dependencies.Dockerfile | 2 +- requirements.txt | 5 +- setup.sh | 9 +- 5 files changed, 145 insertions(+), 5 deletions(-) create mode 120000 .dockerignore create mode 100644 constraints.txt diff --git a/.dockerignore b/.dockerignore new file mode 120000 index 0000000000..3e4e48b0b5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +.gitignore \ No newline at end of file diff --git a/constraints.txt b/constraints.txt new file mode 100644 index 0000000000..4f9a3edc1c --- /dev/null +++ b/constraints.txt @@ -0,0 +1,133 @@ +absl-py==1.4.0 +aqtp==0.6.1 +array-record==0.5.0 +astroid==3.1.0 +astunparse==1.6.3 +attrs==23.2.0 +cachetools==5.3.3 +certifi==2024.2.2 +charset-normalizer==3.3.2 +chex==0.1.85 +click==8.1.7 +cloud-tpu-diagnostics==0.1.5 +cloudpickle==3.0.0 +contextlib2==21.6.0 +dill==0.3.8 +dm-tree==0.1.8 +etils==1.7.0 +exceptiongroup==1.2.0 +flatbuffers==24.3.7 +flax==0.8.1 +fsspec==2024.2.0 +gast==0.4.0 +google-api-core==2.17.1 +google-auth==2.28.2 +google-auth-oauthlib==1.0.0 +google-cloud-core==2.4.1 +google-cloud-storage==2.15.0 +google-crc32c==1.5.0 +google-pasta==0.2.0 +google-resumable-media==2.7.0 +googleapis-common-protos==1.63.0 +grain-nightly==0.0.6 +grpcio==1.62.1 +gviz-api==1.10.0 +h5py==3.10.0 +idna==3.6 +immutabledict==4.2.0 +importlab==0.8.1 +importlib_resources==6.3.0 +iniconfig==2.0.0 +isort==5.13.2 +jax==0.4.25 +jaxlib==0.4.25 +jaxtyping==0.2.28 +Jinja2==3.1.3 +keras==2.13.1 +libclang==16.0.6 +libcst==1.2.0 +Markdown==3.5.2 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +mccabe==0.7.0 +mdurl==0.1.2 +ml-collections==0.1.1 +ml-dtypes==0.3.2 +mlperf-logging==3.0.0 +more-itertools==10.2.0 +msgpack==1.0.8 +msgspec==0.18.6 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.1 +ninja==1.11.1.1 +numpy==1.24.3 +nvidia-cublas-cu12==12.4.2.65 +nvidia-cuda-cupti-cu12==12.4.99 +nvidia-cuda-nvcc-cu12==12.4.99 +nvidia-cuda-nvrtc-cu12==12.4.99 +nvidia-cuda-runtime-cu12==12.4.99 +nvidia-cudnn-cu12==8.9.7.29 +nvidia-cufft-cu12==11.2.0.44 +nvidia-cusolver-cu12==11.6.0.99 +nvidia-cusparse-cu12==12.3.0.142 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.4.99 +oauthlib==3.2.2 +opt-einsum==3.3.0 +optax==0.2.1 +orbax-checkpoint==0.5.5 +packaging==24.0 +pandas==2.2.1 +platformdirs==4.2.0 +pluggy==1.4.0 +promise==2.3 +protobuf==3.20.3 +psutil==5.9.8 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycnite==2023.10.11 +pydot==2.0.0 +Pygments==2.17.2 +pylint==3.1.0 +pyparsing==3.1.2 +pytest==8.1.1 +python-dateutil==2.9.0.post0 +pytype==2024.3.11 +pytz==2024.1 +PyYAML==6.0.1 +requests==2.31.0 +requests-oauthlib==1.4.0 +rich==13.7.1 +rsa==4.9 +scipy==1.12.0 +sentencepiece==0.1.97 +six==1.16.0 +tabulate==0.9.0 +tensorboard==2.13.0 +tensorboard-data-server==0.7.2 +tensorboard_plugin_profile==2.15.1 +tensorboardX==2.6.2.2 +tensorflow==2.13.1 +tensorflow-datasets==4.9.4 +tensorflow-estimator==2.13.0 +tensorflow-hub==0.16.1 +tensorflow-io-gcs-filesystem==0.36.0 +tensorflow-metadata==1.14.0 +tensorflow-text==2.13.0 +tensorstore==0.1.54 +termcolor==2.4.0 +tf-keras==2.15.0 +toml==0.10.2 +tomli==2.0.1 +tomlkit==0.12.4 +toolz==0.12.1 +tqdm==4.66.2 +typeguard==2.13.3 +typing-inspect==0.9.0 +typing_extensions==4.5.0 +tzdata==2024.1 +urllib3==2.2.1 +Werkzeug==3.0.1 +wrapt==1.16.0 +zipp==3.18.0 diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index 242435b29c..30c920e211 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -1,4 +1,4 @@ -FROM ghcr.io/nvidia/jax:base +FROM ghcr.io/nvidia/jax:base-2024-03-13 # Install dependencies for adjusting network rto RUN apt-get update && apt-get install -y iproute2 ethtool lsof diff --git a/requirements.txt b/requirements.txt index 8193a84572..22a1827c15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,8 +16,9 @@ pylint pytest pytype sentencepiece==0.1.97 -tensorflow-text>=2.13.0 -tensorflow>=2.13.0 +# Limit tf version pending investigation https://github.com/google/maxtext/issues/516. +tensorflow-text>=2.13.0,<2.15 +tensorflow>=2.13.0,<2.15 tensorflow-datasets tensorboardx tensorboard-plugin-profile diff --git a/setup.sh b/setup.sh index 1b41f9e534..2c3de73686 100644 --- a/setup.sh +++ b/setup.sh @@ -138,7 +138,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then pip3 install -U "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" - pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip3 install --no-cache-dir "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt fi fi elif [[ $MODE == "nightly" ]]; then @@ -205,4 +205,9 @@ else fi # Install dependencies from requirements.txt -cd $run_name_folder_path && pip install --upgrade pip && pip3 install -r requirements.txt +cd $run_name_folder_path && pip install --upgrade pip +if [[ $DEVICE == "gpu" ]] && [[ "$MODE" == "stable" || ! -v MODE ]] && [[ ! -v JAX_VERSION ]]; then + pip3 install --no-cache-dir -r requirements.txt -c constraints.txt +else + pip3 install -U -r requirements.txt +fi From 34374258b318df334ce8e202c111bb368d46cf2f Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Thu, 14 Mar 2024 17:12:02 +0000 Subject: [PATCH 07/10] Allow NCCL_DEBUG override --- gpu_multi_process_run.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh index f496be3a3e..2231e3ceed 100644 --- a/gpu_multi_process_run.sh +++ b/gpu_multi_process_run.sh @@ -21,7 +21,7 @@ set_nccl_gpudirect_tcpx_specific_configuration() { export NCCL_CROSS_NIC=0 export NCCL_ALGO=Ring export NCCL_PROTO=Simple - export NCCL_DEBUG=INFO + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} export NCCL_NET_GDR_LEVEL=PIX export NCCL_P2P_PXN_LEVEL=0 export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION @@ -46,7 +46,7 @@ set_nccl_gpudirect_tcpx_specific_configuration() { elif [[ "$USE_GPUDIRECT" == "fastrak" ]]; then echo "Using GPUDirect-TCPFasTrak" export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION - export NCCL_DEBUG=INFO + export NCCL_DEBUG=${NCCL_DEBUG:-INFO} export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0 export LD_LIBRARY_PATH="/usr/local/fastrak/lib64:${LD_LIBRARY_PATH}" export NCCL_FASTRAK_CTRL_DEV=eth0 From ddc6b6e5e0cc152cf7978099c6f872b9bb7ccecb Mon Sep 17 00:00:00 2001 From: tejasnamagoogle Date: Thu, 14 Mar 2024 23:56:40 +0000 Subject: [PATCH 08/10] adding helm configs to run llama 7b --- gke/gpu/llama/Chart.yaml | 6 ++ gke/gpu/llama/templates/llama_70b.yaml | 131 +++++++++++++++++++++++++ gke/gpu/llama/values.yaml | 7 ++ 3 files changed, 144 insertions(+) create mode 100644 gke/gpu/llama/Chart.yaml create mode 100644 gke/gpu/llama/templates/llama_70b.yaml create mode 100644 gke/gpu/llama/values.yaml diff --git a/gke/gpu/llama/Chart.yaml b/gke/gpu/llama/Chart.yaml new file mode 100644 index 0000000000..120902dc1f --- /dev/null +++ b/gke/gpu/llama/Chart.yaml @@ -0,0 +1,6 @@ +apiVersion: v2 +name: llama70b +description: llama70b +type: application +version: 0.1.0 +appVersion: "1.16.0" \ No newline at end of file diff --git a/gke/gpu/llama/templates/llama_70b.yaml b/gke/gpu/llama/templates/llama_70b.yaml new file mode 100644 index 0000000000..c868d5af04 --- /dev/null +++ b/gke/gpu/llama/templates/llama_70b.yaml @@ -0,0 +1,131 @@ +{{- $root := . -}} +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: llama70b-maxtext + labels: + xpk.google.com/workload: llama70b-maxtext +spec: + failurePolicy: + maxRestarts: 0 + replicatedJobs: + - name: slice-job + replicas: 1 + template: + spec: + parallelism: {{ $root.Values.workload.nodes }} + completions: {{ $root.Values.workload.nodes }} + backoffLimit: 0 # When any pod fails, the job is failed + template: + metadata: + labels: + xpk.google.com/workload: llama70b-maxtext + spec: + schedulerName: default-scheduler + restartPolicy: Never + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: cloud.google.com/gke-accelerator + operator: Exists + - key: cloud.google.com/gke-nodepool + operator: In + values: [a3plus-multi-nic] + nodeSelector: + cloud.google.com/gke-accelerator: nvidia-h100-80gb + hostNetwork: true + dnsPolicy: ClusterFirstWithHostNet + terminationGracePeriodSeconds: 30 + tolerations: + - operator: "Exists" + key: nvidia.com/gpu + volumes: + - name: nvidia-install-dir-host + hostPath: + path: /home/kubernetes/bin/nvidia/lib64 + - name: shared-memory + emptyDir: + medium: "Memory" + sizeLimit: 1Gi + - name: workload-terminated-volume + emptyDir: + containers: + - name: fastrak-daemon + image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.3 + imagePullPolicy: Always + command: + - "bash" + - "-c" + - | + set -ex; chmod 755 /fts/entrypoint_rxdm_container.sh; /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr & + while [ ! -e "/usr/share/maxtext/workload_terminated" ]; do sleep 10; echo "sleeping"; done + args: + - | + set -ex + chmod 755 /fts/entrypoint_rxdm_container.sh + /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr + sleep 1000 + securityContext: + privileged: true + volumeMounts: + - name: nvidia-install-dir-host + mountPath: /usr/local/nvidia/lib64 + - name: workload-terminated-volume + mountPath: /usr/share/maxtext + env: + - name: LD_LIBRARY_PATH + value: /usr/local/nvidia/lib64 + - name: maxtext-fastrak + image: "{{ $root.Values.workload.image }}" + imagePullPolicy: Always + securityContext: + privileged: true + ports: + - containerPort: 6002 + env: + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: JAX_COORDINATOR_ADDRESS + value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)" + - name: NNODES + value: "{{ $root.Values.workload.nodes }}" + - name: NODE_RANK + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index'] + - name: USE_GPUDIRECT + value: "fastrak" + - name: GPUS_PER_NODE + value: "8" + - name: JAX_COORDINATOR_PORT + value: "6002" + - name: LD_LIBRARY_PATH + value: /usr/local/nvidia/lib64 + + - name: COMMAND + value: "python MaxText/train.py MaxText/configs/base.yml hardware=gpu run_name=2024-03-07-20-59 steps={{ $root.Values.workload.steps }} per_device_batch_size={{ $root.Values.workload.per_device_batch_size }} model_name={{ $root.Values.workload.model_name }} enable_checkpointing=false attention=dot_product dataset_type=synthetic async_checkpointing=false" + - name: XLA_FLAGS + value: {{ $root.Values.workload.xla_flags }} + command: + - "bash" + - "-c" + - | + echo XPK Start: $(date) ; _sigterm() ( kill -SIGTERM $!;); trap _sigterm SIGTERM; (cd /deps && bash gpu_multi_process_run.sh) & PID=$!; while kill -0 $PID 2>/dev/null; do sleep 5; done; EXIT_CODE=$? ; echo XPK End: $(date); echo EXIT_CODE=$EXIT_CODE; echo Main app is done > /usr/share/maxtext/workload_terminated + volumeMounts: + - name: nvidia-install-dir-host + mountPath: /usr/local/nvidia/lib64 + - name: shared-memory + mountPath: /dev/shm + - name: workload-terminated-volume + mountPath: /usr/share/maxtext + resources: + limits: + nvidia.com/gpu: 8 diff --git a/gke/gpu/llama/values.yaml b/gke/gpu/llama/values.yaml new file mode 100644 index 0000000000..c6c0d5d965 --- /dev/null +++ b/gke/gpu/llama/values.yaml @@ -0,0 +1,7 @@ +workload: + nodes: 16 + image: "us-central1-docker.pkg.dev/gce-ai-infra/maxtext/maxtext_base_image:03_14_2024_release" + steps: 30 + per_device_batch_size: 12 + model_name: llama2-7b + xla_flags: "--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization --xla_gpu_enable_async_collective_permute=true --xla_gpu_enable_async_all_to_all=true" \ No newline at end of file From 97181330d2a690d254965c10df81d1bf0831f9ae Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Thu, 14 Mar 2024 17:44:03 +0000 Subject: [PATCH 09/10] Revert "Fix the issue of creating 8 processes per GPU." This reverts commit fc0f9cf85913594bfcd22d843a0c495b5f0163ca. Also partially reverts 1046886607644b6191ed50a0fb8dce0b45c45693 to keep launch script compatible --- MaxText/max_utils.py | 6 ++++-- gpu_multi_process_run.sh | 11 ++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index aef441df44..927aff4bd7 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -196,8 +196,10 @@ def initialize_jax_for_gpu(): coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", - num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK"))) + num_processes=int(os.getenv("JAX_NUM_PROCESSES")), + process_id=int(os.getenv("PROCESS_ID")), + local_device_ids=int(os.getenv("LOCAL_DEVICE_ID")), + ) max_logging.log(f"JAX global devices: {jax.devices()}") def initialize_jax_for_cpu(): diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh index 2231e3ceed..df4816d5ce 100644 --- a/gpu_multi_process_run.sh +++ b/gpu_multi_process_run.sh @@ -14,6 +14,7 @@ set -o pipefail export GPUS_PER_NODE=$GPUS_PER_NODE export JAX_COORDINATOR_PORT=$JAX_COORDINATOR_PORT export JAX_COORDINATOR_ADDRESS=$JAX_COORDINATOR_ADDRESS +export JAX_NUM_PROCESSES=$((NNODES * GPUS_PER_NODE)) set_nccl_gpudirect_tcpx_specific_configuration() { if [[ "$USE_GPUDIRECT" == "tcpx" ]]; then @@ -145,9 +146,13 @@ resolve_coordinator_ip set -e PIDS=() -${COMMAND} & -PID=$! -PIDS+=($PID) +for ((LOCAL_DEVICE_ID=0; LOCAL_DEVICE_ID <= $((GPUS_PER_NODE - 1)); LOCAL_DEVICE_ID++)); do + PROCESS_ID=$(($GPUS_PER_NODE*$NODE_RANK + $LOCAL_DEVICE_ID)) + LOCAL_DEVICE_ID=$LOCAL_DEVICE_ID PROCESS_ID=$PROCESS_ID ${COMMAND} & + PID=$! + PIDS+=($PID) + echo "Launched MaxText/train.py for local_device_id: $LOCAL_DEVICE_ID process_id: $PROCESS_ID and PID $PID" +done wait_all_success_or_exit "${PIDS[@]}" From 5975656ce517b59992372663052837210a025125 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Tue, 26 Mar 2024 23:37:15 +0000 Subject: [PATCH 10/10] Allow overriding NCCL_FASTRAK_NUM_FLOWS --- gpu_multi_process_run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpu_multi_process_run.sh b/gpu_multi_process_run.sh index df4816d5ce..6abbe953e5 100644 --- a/gpu_multi_process_run.sh +++ b/gpu_multi_process_run.sh @@ -64,7 +64,7 @@ set_nccl_gpudirect_tcpx_specific_configuration() { export NCCL_P2P_NET_CHUNKSIZE=524288 export NCCL_P2P_PCI_CHUNKSIZE=524288 export NCCL_P2P_NVL_CHUNKSIZE=1048576 - export NCCL_FASTRAK_NUM_FLOWS=8 + export NCCL_FASTRAK_NUM_FLOWS=${NCCL_FASTRAK_NUM_FLOWS:-8} export NCCL_FASTRAK_FLOWS_PER_GROUP=2 export NCCL_BUFFSIZE=4194304 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7