diff --git a/.github/workflows/build_documentation.yaml b/.github/workflows/build_documentation.yaml
deleted file mode 100644
index 6925489ae..000000000
--- a/.github/workflows/build_documentation.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-name: Build the documentation
-
-on:
- pull_request:
- branches: [main]
-
-jobs:
- build:
- name: Build
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v4
- with:
- python-version: "3.10"
- - run: pip install -r requirements-doc.txt
- - name: Build the documentation
- run: mkdocs build
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
new file mode 100644
index 000000000..49544317f
--- /dev/null
+++ b/.github/workflows/ci.yaml
@@ -0,0 +1,85 @@
+name: CI
+
+on:
+ schedule:
+ - cron: "0 10 * * *"  # every day at 10:00 (GitHub evaluates cron schedules in UTC)
+ push:
+ branches:
+ - "main"
+ tags:
+ - "v*.*.*"  # release tags such as v1.2.3; consumed by the type=semver image tags in the docker job
+ pull_request:
+ branches:
+ - "main"
+
+jobs:
+ test:
+ name: Test
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: '3.12'
+ cache: 'pip'  # reuse pip's download cache between runs
+
+ - name: Install dependencies  # torch goes first; presumably needed at build time because the next install disables build isolation
+ run: |
+ pip install "torch>=2.2.2"
+ FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+
+ - name: Run tests
+ run: pytest .  # tests decorated with tests.common.requires_cuda are skipped when CUDA is unavailable
+
+  docker:
+    name: Docker
+    runs-on: ubuntu-latest
+    needs: test  # build the image only after the test job has passed
+    steps:
+      - name: Clean unused files  # free disk space on the runner before the image build
+        run: |
+          sudo rm -rf /usr/local/lib/android || true # will release about 10 GB
+          sudo rm -rf /usr/share/dotnet || true # will release about 20GB
+          sudo rm -rf /opt/ghc || true
+          sudo rm -rf /usr/local/.ghcup || true
+
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Docker meta
+        id: meta  # referenced below as steps.meta.outputs.tags / .labels
+        uses: docker/metadata-action@v5
+        with:
+          images: |
+            ghcr.io/servicenow/fast-llm
+          tags: |  # the per-tag switch is spelled "enable"; "latest" is restricted to builds of main
+            type=schedule
+            type=ref,event=branch
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=semver,pattern={{major}}
+            type=sha
+            type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Login to GHCR
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.repository_owner }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Build and push
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          # push: ${{ github.event_name != 'pull_request' }}
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
+          cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max
diff --git a/.github/workflows/deploy_documentation.yaml b/.github/workflows/deploy_documentation.yaml
deleted file mode 100644
index 92bc27524..000000000
--- a/.github/workflows/deploy_documentation.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-name: Publish the documentation
-
-on:
- push:
- branches:
- - main
-
-permissions:
- contents: write
-
-jobs:
- deploy:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-python@v4
- with:
- python-version: "3.10"
- - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- - uses: actions/cache@v3
- with:
- key: mkdocs-material-${{ env.cache_id }}
- path: .cache
- restore-keys: |
- mkdocs-material-
- - run: pip install -r requirements-doc.txt
- - name: Publish the documentation
- run: mkdocs gh-deploy --force
diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
new file mode 100644
index 000000000..6122805fe
--- /dev/null
+++ b/.github/workflows/docs.yaml
@@ -0,0 +1,61 @@
+name: Documentation
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+permissions:  # least privilege by default; only the deploy job below gets write access
+  contents: read
+
+jobs:
+  build:
+    name: Build
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+          cache: "pip"
+      - run: echo "cache_id=$(date --utc '+%V')" >> "$GITHUB_ENV"  # week number, so the cache key rotates weekly
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: |
+          pip install "torch>=2.2.2"
+          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+      - name: Build the documentation
+        run: mkdocs build
+
+  deploy:
+    if: github.event_name == 'push'  # publish only for pushes to main, never for pull requests
+    name: Deploy
+    needs: build
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write  # mkdocs gh-deploy pushes the built site back to the repository
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+          cache: "pip"
+      - run: echo "cache_id=$(date --utc '+%V')" >> "$GITHUB_ENV"  # week number, so the cache key rotates weekly
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: |
+          pip install "torch>=2.2.2"
+          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+      - name: Publish the documentation
+        run: mkdocs gh-deploy --force --dirty
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..c2324ad83
--- /dev/null
+++ b/README.md
@@ -0,0 +1,161 @@
+
+
+

+
+[![Docker][ci-badge]][ci-workflow]
+[![Documentation][docs-badge]][docs-workflow]
+[![License][license-badge]][license]
+
+*Accelerating your LLM training to full speed*
+
+Made with ❤️ by [ServiceNow Research][servicenow-research]
+
+
+
+## Overview
+
+Fast-LLM is a new open-source library for training large language models, built on [PyTorch][pytorch] and [Triton][triton]. It is extremely fast, scales to large clusters, supports a wide range of model architectures, and is easy to use. Unlike commercial frameworks like Megatron-LM, which are largely closed off and fragmented across forks, Fast-LLM is fully open-source and encourages community-driven development. Researchers can freely customize and optimize as needed, making it a flexible and hackable alternative that combines the speed of specialized tools with the openness of libraries like [Hugging Face Transformers][transformers].
+
+> [!NOTE]
+> Fast-LLM is not affiliated with Fast.AI, FastHTML, FastAPI, FastText, or other similarly named projects. Our library's name refers to its speed and efficiency in language model training.
+
+## Why Fast-LLM?
+
+1. 🚀 **Fast-LLM is Blazingly Fast**:
+   - ⚡️ Optimized kernel efficiency and reduced overheads.
+   - 🔋 Optimized memory usage for best performance.
+   - ⏳ Minimizes training time and cost.
+
+2. 📈 **Fast-LLM is Highly Scalable**:
+   - 📡 Distributed training across multiple GPUs and nodes using 3D parallelism (Data, Tensor, and Pipeline).
+   - 🔗 Supports sequence length parallelism to handle longer sequences effectively.
+   - 🧠 ZeRO-1, ZeRO-2, and ZeRO-3 implementations for improved memory efficiency.
+   - 🎛️ Mixed precision training support for better performance.
+   - 🏋️‍♂️ Large batch training and gradient accumulation support.
+   - 🔄 Reproducible training with deterministic behavior.
+
+3. 🎨 **Fast-LLM is Incredibly Flexible**:
+   - 🤖 Compatible with all common language model architectures in a unified class.
+   - ⚡ Efficient dropless Mixture-of-Experts (MoE) implementation with SoTA performance.
+   - 🧩 Customizable language model architectures, data loaders, loss functions, and optimizers (in progress).
+   - 🤗 Seamless integration with [Hugging Face Transformers][transformers].
+
+4. 🎯 **Fast-LLM is Super Easy to Use**:
+   - 📦 [Pre-built Docker images](https://github.com/ServiceNow/Fast-LLM/pkgs/container/fast-llm) for quick deployment.
+   - 📝 Simple YAML configuration for hassle-free setup.
+   - 💻 Command-line interface for easy launches.
+   - 📊 Detailed logging and real-time monitoring features.
+   - 📚 Extensive [documentation][docs] and practical tutorials (in progress).
+
+5. 🌐 **Fast-LLM is Truly Open Source**:
+   - ⚖️ Licensed under [Apache 2.0][license] for maximum freedom to use Fast-LLM at work, in your projects, or for research.
+   - 💻 Fully developed on GitHub with a public [roadmap][roadmap] and transparent [issue tracking][issues].
+   - 🤝 Contributions and collaboration are always welcome!
+
+## Usage
+
+We'll walk you through how to use Fast-LLM to train a large language model on a cluster with multiple nodes and GPUs. We'll show an example setup using a Slurm cluster and a Kubernetes cluster.
+
+For this demo, we will train a Mistral-7B model from scratch for 100 steps on random data. The config file `examples/mistral-4-node-benchmark.yaml` is pre-configured for a multi-node setup with 4 DGX nodes, each with 8 A100-80GB or H100-80GB GPUs.
+
+> [!NOTE]
+> Fast-LLM scales from a single GPU to large clusters. You can start small and expand based on your resources.
+
+Expect to see a significant speedup in training time compared to other libraries! For training Mistral-7B, Fast-LLM is expected to achieve a throughput of **9,800 tokens/s/H100** (batch size 32, sequence length 8k) on a 4-node cluster with 32 H100s.
+
+### Running Fast-LLM on a Slurm Cluster
+
+#### Prerequisites
+
+- A [Slurm](https://slurm.schedmd.com/) cluster with at least 4 DGX nodes with 8 A100-80GB or H100-80GB GPUs each.
+- CUDA 12.1 or higher.
+- Dependencies: [PyTorch][pytorch], [Triton][triton], and [Apex](https://github.com/NVIDIA/apex) installed on all nodes.
+
+#### Steps
+
+1. Deploy the [nvcr.io/nvidia/pytorch:24.07-py3](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) Docker image to all nodes (recommended), because it contains all the necessary dependencies.
+2. Install Fast-LLM on all nodes:
+
+ ```bash
+ sbatch <
+
+## Reporting a Vulnerability
+
+To report a security vulnerability in Fast-LLM, please email our [Product Security Incident Response Team (PSIRT)](https://securitylab.servicenow.com) at [disclosure@servicenow.com](mailto:disclosure@servicenow.com). Include a detailed description of the issue, steps to reproduce it, and any relevant information that may help in investigating the matter.
+
+## Guidelines
+
+Please follow the guidelines below when [disclosing vulnerabilities](https://www.servicenow.com/company/trust/privacy/responsible-disclosure.html):
+
+- Report any potential security issue as soon as possible. ServiceNow will make every effort to quickly resolve the issue.
+- Provide sufficient detail to reproduce the vulnerability, including proof of concept. The use of ReproNow to demonstrate reproducibility is encouraged but not required.
+- Please do not disclose an issue to the public or any third party until ServiceNow has resolved it.
+- Make a good faith effort to avoid privacy violations, data destruction, and interruption or degradation of our services. Only interact with accounts you own or have explicit permission from the account holder to access.
+- Redact any language or images that may identify the program or ServiceNow customers from information about a resolved vulnerability.
+- Do not engage in disruptive testing (such as Denial of Service attacks) or any action that could impact the confidentiality, integrity, or availability of information and systems.
+- Do not engage in social engineering or phishing against customers or employees.
+- Please do not request compensation for time, materials, or discovered vulnerabilities through the Responsible Disclosure Program.
diff --git a/examples/example_config.yaml b/examples/example_config.yaml
deleted file mode 100644
index c23d7c7b1..000000000
--- a/examples/example_config.yaml
+++ /dev/null
@@ -1,54 +0,0 @@
-run:
- experiment_dir: null
- wandb_post_alerts: false
-model:
- base_model:
- transformer:
- num_layers: 12
- hidden_size: 1024
- num_attention_heads: 8
- head_groups: 1
- add_linear_biases: true
- ffn_hidden_size: 4096
- kv_channels: 128
- activation_type: gelu
- init_method_std: 0.03125
- init_method_std_qkv: 0.03125
- init_method_std_attn_proj: 0.0063788795384978605
- init_method_std_mlp_1: 0.03125
- init_method_std_mlp_2: 0.0063788795384978605
- mlp_lr_scale:
- - null
- vocab_size: 49152
- use_position_embeddings: true
- tie_word_embeddings: true
- init_method_std_embed: 0.03125
- distributed:
- distributed_timeout: 60.0
- training_dtype: float32
-pretrained:
- pretrained_checkpoint_path: null
- pretrained_checkpoint_type: distributed
-batch:
- micro_batch_size: 1
- depth_first_micro_batches: 1
- breadth_first_micro_batches: 1
- sequential_micro_batches: 1
- batch_size: 1
- sequence_length: 2048
- micro_sequence_length: 2048
-data:
- split:
- - 969.0
- - 30.0
- - 1.0
- dataset_source: list
- data_path:
- - fkgtiu
- data_sample_warn_time_ms: 1000.0
-profiling:
- profile_cuda: false
- profile_ranks: []
-optimizer:
- weight_decay: 0.01
- initial_loss_scale: 65536.0
diff --git a/examples/fast-llm-pvc.yaml b/examples/fast-llm-pvc.yaml
new file mode 100644
index 000000000..b26e27eb8
--- /dev/null
+++ b/examples/fast-llm-pvc.yaml
@@ -0,0 +1,12 @@
+# Create persistent volume claim for Fast-LLM
+apiVersion: "v1"
+kind: "PersistentVolumeClaim"
+metadata:
+ name: "pvc-fast-llm-home"  # referenced as claimName by examples/fast-llm.pytorchjob.yaml
+spec:
+ storageClassName: local-path  # NOTE(review): cluster-specific; confirm this class can provide ReadWriteMany volumes
+ accessModes:
+ - ReadWriteMany  # the PyTorchJob mounts this claim in both its Master and Worker pods
+ resources:
+ requests:
+ storage: 1000Gi
diff --git a/examples/fast-llm.pytorchjob.yaml b/examples/fast-llm.pytorchjob.yaml
new file mode 100644
index 000000000..9decff91f
--- /dev/null
+++ b/examples/fast-llm.pytorchjob.yaml
@@ -0,0 +1,127 @@
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+ name: "fast-llm"
+spec:
+ nprocPerNode: "8"  # matches the 8 GPUs requested by each replica pod
+ pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: Never
+      template:
+        spec:
+          tolerations:
+            - key: nvidia.com/gpu
+              value: "true"
+              operator: Equal
+              effect: NoSchedule
+          containers:
+            - name: pytorch
+              image: ghcr.io/servicenow/fast-llm:latest  # image published by .github/workflows/ci.yaml
+              resources:
+                limits:
+                  nvidia.com/gpu: 8
+                  rdma/rdma_shared_device_a: 1
+                  memory: "1024Gi"
+                  cpu: 128  # was an empty (null) value; mirrors the request like the other limits
+                requests:
+                  nvidia.com/gpu: 8
+                  rdma/rdma_shared_device_a: 1
+                  memory: "1024Gi"
+                  cpu: 128
+              command:
+                - /bin/bash
+                - -c
+                - |
+                  torchrun --rdzv_backend=static \
+                           --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
+                           --node_rank=${RANK} \
+                           --nproc_per_node=${PET_NPROC_PER_NODE} \
+                           --nnodes=${PET_NNODES} \
+                           --max_restarts=0 \
+                           --rdzv_conf=timeout=3600 \
+                           --no_python \
+                           fast-llm train gpt \
+                           --config examples/mistral-4-node-benchmark.yaml
+              env:
+                - name: NCCL_DEBUG
+                  value: "INFO"
+                - name: PYTHONHASHSEED
+                  value: "0"
+              securityContext:
+                capabilities:
+                  add:
+                    - IPC_LOCK
+              volumeMounts:
+                - mountPath: /home/fast-llm
+                  name: fast-llm-home
+                - mountPath: /dev/shm
+                  name: dshm
+          volumes:
+            - name: fast-llm-home
+              persistentVolumeClaim:
+                claimName: pvc-fast-llm-home  # claim defined in examples/fast-llm-pvc.yaml
+            - name: dshm
+              emptyDir:
+                medium: Memory
+                sizeLimit: "1024Gi"
+    Worker:
+      replicas: 3  # three workers plus the single master make up the 4-node job
+      restartPolicy: Never
+      template:
+        spec:
+          tolerations:
+            - key: nvidia.com/gpu
+              value: "true"
+              operator: Equal
+              effect: NoSchedule
+          containers:
+            - name: pytorch
+              image: ghcr.io/servicenow/fast-llm:latest  # image published by .github/workflows/ci.yaml
+              resources:
+                limits:
+                  nvidia.com/gpu: 8
+                  rdma/rdma_shared_device_a: 1
+                  memory: "1024Gi"
+                  cpu: 128  # was an empty (null) value; mirrors the request like the other limits
+                requests:
+                  nvidia.com/gpu: 8
+                  rdma/rdma_shared_device_a: 1
+                  memory: "1024Gi"
+                  cpu: 128
+              command:
+                - /bin/bash
+                - -c
+                - |
+                  torchrun --rdzv_backend=static \
+                           --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
+                           --node_rank=${RANK} \
+                           --nproc_per_node=${PET_NPROC_PER_NODE} \
+                           --nnodes=${PET_NNODES} \
+                           --max_restarts=0 \
+                           --rdzv_conf=timeout=3600 \
+                           --no_python \
+                           fast-llm train gpt \
+                           --config examples/mistral-4-node-benchmark.yaml
+              env:
+                - name: NCCL_DEBUG
+                  value: "INFO"
+                - name: PYTHONHASHSEED
+                  value: "0"
+              securityContext:
+                capabilities:
+                  add:
+                    - IPC_LOCK
+              volumeMounts:
+                - mountPath: /home/fast-llm
+                  name: fast-llm-home
+                - mountPath: /dev/shm
+                  name: dshm
+          volumes:
+            - name: fast-llm-home
+              persistentVolumeClaim:
+                claimName: pvc-fast-llm-home  # claim defined in examples/fast-llm-pvc.yaml
+            - name: dshm
+              emptyDir:
+                medium: Memory
+                sizeLimit: "1024Gi"
diff --git a/examples/fast-llm.sbat b/examples/fast-llm.sbat
new file mode 100644
index 000000000..13a966ec3
--- /dev/null
+++ b/examples/fast-llm.sbat
@@ -0,0 +1,43 @@
+#!/bin/bash
+#SBATCH --job-name=fast_llm_train
+#SBATCH --nodes=4
+#SBATCH --gpus-per-node=8
+#SBATCH --ntasks-per-node=1
+#SBATCH --exclusive
+#SBATCH --output=job_output.log
+#SBATCH --error=job_error.log
+
+# Launch Fast-LLM on 4 nodes x 8 GPUs: srun starts one task per node, and each
+# task runs torchrun, which spawns one process per GPU on that node.
+
+# Rendezvous endpoint: first host of the allocation, on a fixed port.
+MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+MASTER_PORT=8001
+
+echo "$MASTER_ADDR"
+
+export NCCL_DEBUG=WARN
+export NCCL_SOCKET_IFNAME=eno1  # NOTE(review): interface name is cluster-specific
+export UCX_TLS=self,shm,tcp
+export NCCL_NET_GDR_LEVEL=PIX
+export NCCL_IB_PCI_RELAXED_ORDERING=1
+export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+export PYTHONHASHSEED=0
+export TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1
+
+# ${MASTER_ADDR} and ${MASTER_PORT} are expanded by this script before srun runs;
+# the escaped \$SLURM_* variables are expanded by the shell srun starts on each node.
+srun --gpus-per-node="$SLURM_GPUS_PER_NODE" \
+     --ntasks-per-node="$SLURM_NTASKS_PER_NODE" \
+     bash -c "
+     torchrun --rdzv_backend=static \
+              --rdzv_id=0 \
+              --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
+              --node_rank=\$SLURM_NODEID \
+              --nproc_per_node=\$SLURM_GPUS_PER_NODE \
+              --nnodes=\$SLURM_NNODES \
+              --max_restarts=0 \
+              --rdzv_conf=timeout=3600 \
+              --no_python \
+              fast-llm train gpt \
+              --config examples/mistral-4-node-benchmark.yaml"
diff --git a/examples/mistral-4-node-benchmark.yaml b/examples/mistral-4-node-benchmark.yaml
new file mode 100644
index 000000000..99dd0ee78
--- /dev/null
+++ b/examples/mistral-4-node-benchmark.yaml
@@ -0,0 +1,55 @@
+training:  # benchmark config referenced by examples/fast-llm.pytorchjob.yaml and the README walkthrough
+ train_iters: 100  # short run: the README demo trains for 100 steps
+ num_workers: 8
+ logs:
+ interval: 10
+ validation:
+ iterations: null
+ test_iters: 0
+batch:
+ sequence_length: 8192
+ micro_batch_size: 1
+ batch_size: 32  # the README throughput figure is quoted for batch size 32, sequence length 8k
+data:
+ format: random  # the README demo trains on random data
+ split: [1, 0, 0]
+optimizer:
+ learning_rate:
+ base: 1.0e-05
+ decay_style: constant
+ warmup_iterations: 0
+ weight_decay: 0.1
+ beta_1: 0.9
+ beta_2: 0.95
+model:
+ base_model:
+ transformer:
+ normalization:
+ type: rms_norm
+ epsilon: 1.0e-05
+ num_layers: 32
+ hidden_size: 4096
+ ffn_hidden_size: 14336
+ num_attention_heads: 32
+ head_groups: 8
+ add_linear_biases: false
+ use_rotary_embeddings: true
+ gated: true
+ activation_type: silu
+ triton_rotary: true
+ kv_channels: 128
+ rotary_embedding_scale: -9.210340371976184  # numerically equal to -ln(10000)
+ window_size: 4096
+ init_method_std: 0.009021
+ attention_dropout: 0.0
+ hidden_dropout: 0.0
+ vocab_size: 32000
+ tie_word_embeddings: false
+ multi_stage:
+ zero_stage: 2
+ distributed:
+ training_dtype: bf16
+ distributed_timeout: 3600
+ seed: 984059
+run:
+ experiment_dir: mistral_4_nodes_benchmark
diff --git a/examples/train_mistral.sh b/examples/train_mistral.sh
deleted file mode 100644
index 5745e38c8..000000000
--- a/examples/train_mistral.sh
+++ /dev/null
@@ -1,128 +0,0 @@
-# Required or optional environment variables
-# export PROJECT_DIR=
-# export PROJECT_NAME=
-# export PROJECT_VERSION=
-# export DATA_PATH_LIST=
-# export DATA_PATH_JSON=
-# export PRETRAINED_MISTRAL_PATH=
-# export PRETRAINED_MIXTRAL_PATH=
-
-export CMD_ARGS="fast-llm train gpt"
-
-export MODEL_ARGS_PRETRAINED="\
---pretrained_checkpoint_type=huggingface \
---pretrained_checkpoint_path=$PRETRAINED_MISTRAL_PATH \
---use_pretrained_config=1 \
-"
-
-export MODEL_ARGS_ARCHITECTURE="\
---num_layers=32 \
---hidden_size=4096 \
---vocab_size=32000 \
---num_attention_heads=32 \
---head_groups=8 \
---add_linear_biases=0 \
---ffn_hidden_size=14336 \
---kv_channels=128 \
---use_rotary_embeddings=1 \
---rotary_embedding_scale=-9.210340371976184 \
---gated=1 \
---activation_type=silu \
---normalization_type=rms_norm \
---tie_word_embeddings=0 \
---window_size=4096 \
-"
-
-export DATA_ARGS_JSON="\
---split=9998,2,0 \
---dataset_source=file \
---data_path=$DATA_PATH_JSON \
-"
-
-export DATA_ARGS_LIST="\
---split=9998,2,0 \
---dataset_source=list \
---data_path=$DATA_PATH_DATA_ARGS_LIST \
-"
-
-export TRAINING_ARGS="\
---batch_size=128 \
---sequence_length=8192 \
---train_iters=500000 \
---weight_decay=0.1 \
---adam_beta1=0.9 \
---adam_beta2=0.95 \
---clip_grad=1.0 \
---lr=0.0001 \
---lr_warmup_iters=1000 \
---lr_decay_style=cosine \
---lr_decay_iters=500000 \
---min_lr=0.000003 \
-"
-
-export PERFORMANCE_ARGS="\
---micro_batch_size=1 \
---training_dtype=bf16 \
---zero_stage=2 \
---num_workers=8 \
-"
-
-export MONITORING_ARGS="\
---validation_iters=25 \
---validation_interval=1000 \
---log_interval=10 \
---log_offset=0 \
---checkpoint_interval=500 \
---max_checkpoints=5 \
---export_interval=25000 \
---wandb_status_interval=25000 \
---wandb_entity_name=$WANDB_ENTITY_NAME \
---wandb_project_name=$PROJECT_NAME \
---wandb_group_name=$PROJECT_VERSION \
-"
-
-export ALL_ARGS="\
-$CMD_ARGS \
-$MODEL_ARGS_PRETRAINED \
-$DATA_ARGS_LIST \
-$TRAINING_ARGS \
-$PERFORMANCE_ARGS \
-$MONITORING_ARGS \
-"
-
-export MODEL_ARGS_MIXTRAL_ARCHITECTURE="\
-$MODEL_ARGS_ARCHITECTURE \
---num_experts=8 \
---num_experts_per_token=2 \
-"
-
-export MIXTRAL_ARGS="\
---pretrained_checkpoint_path=$PRETRAINED_MIXTRAL_PATH \
---zero_stage=3 \
---mlp_recompute_level=activation \
-"
-
-export PROFILE_ARGS="\
---profile_cuda=1 \
---profile_skip=10 \
---profile_wait=95 \
---profile_warmup=2 \
---profile_cycles=3 \
---profile_export=1 \
-"
-
-
-run_local () { # run(name, num_gpus, base_cmd)
- echo $1 $2 $3
- export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python"
- $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1
-}
-
-run_c10d () { # run(name, num_nodes, base_cmd)
- echo $1 $2 $3
- export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR"
- $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/$PROJECT_NAME_$PROJECT_VERSION/$1
-}
-
-run_c10d mistral_example 16 "$ALL_ARGS"
-# run_c10d mixtral_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50"
diff --git a/requirements-doc.txt b/requirements-doc.txt
deleted file mode 100644
index b7bd2efbb..000000000
--- a/requirements-doc.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-mkdocs
-mkdocs-material
-mkdocs-material[imaging]
-mkdocs-section-index
-mkdocstrings[python]
-mkdocs-git-committers-plugin-2
-mkdocs-git-revision-date-localized-plugin
diff --git a/setup.cfg b/setup.cfg
index d5de782ea..55816ff49 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,7 +15,7 @@ install_requires =
CORE =
# Available through the nvidia base image
# Keeping an older min version because later ones have no x86 wheel for Mac OS
- torch >=2.2.2
+ torch>=2.2.2
# Numpy major needs to match torch
numpy>=1.24.4,<2.0.0
# Used for checkpoints
@@ -39,6 +39,15 @@ DEV =
pytest>=8.3.2
pytest-depends>=1.0.1
+# Required for building the documentation with MkDocs (installed in CI via `pip install -e ".[CORE,OPTIONAL,DEV,DOCS]"`)
+DOCS =
+ mkdocs
+ mkdocs-material
+ mkdocs-material[imaging]
+ mkdocs-section-index
+ mkdocstrings[python]
+ mkdocs-git-committers-plugin-2
+ mkdocs-git-revision-date-localized-plugin
[options.entry_points]
console_scripts =
diff --git a/tests/common.py b/tests/common.py
index 127dfb731..edc6d2111 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -164,6 +164,9 @@
TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_MODEL_TYPE = _CONFIGS[TEST_MODEL]
+requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+
+
def get_test_data():
if not TOKENIZER_FILE.is_file():
import transformers
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
index 50ba105fa..fd0c2c3cf 100644
--- a/tests/test_checkpoint.py
+++ b/tests/test_checkpoint.py
@@ -18,6 +18,7 @@
TEST_MODEL,
TEST_MODEL_TYPE,
TEST_RESULTS_PATH,
+ requires_cuda,
run_test_script,
)
from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor
@@ -29,6 +30,7 @@
TEST_ARCHITECTURE_CONFIG_CLS = TEST_BASE_MODEL_CONFIG_CLS.architecture_cls
+@requires_cuda
@pytest.mark.depends()
def test_checkpoint_and_eval():
# A baseline config (single-gpu, bf16, flash-attn).
diff --git a/tests/test_config.py b/tests/test_config.py
index 840323d60..623106bf7 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,5 +1,7 @@
import pathlib
import subprocess
+import yaml
+from fast_llm.models.auto import trainer_registry
def test_validate_without_import():
@@ -26,3 +28,10 @@ def test_validate_without_import():
completed_proc = subprocess.run(command)
if completed_proc.returncode:
raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
+
+
+def test_validate_example_config():
+ fast_llm_config_dict = yaml.safe_load(
+ (pathlib.Path(__file__).parents[1] / "examples" / "mistral-4-node-benchmark.yaml").read_text()
+ )
+ trainer_registry["gpt"].from_dict(fast_llm_config_dict)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 4b6962164..531ebccbc 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -5,8 +5,10 @@
from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation
from fast_llm.functional.triton.sparse_copy import get_sparse_map
from fast_llm.utils import Assert
+from tests.common import requires_cuda
+@requires_cuda
@pytest.mark.parametrize("gated", [True, False])
@pytest.mark.parametrize(
"activation_type", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu]
@@ -62,6 +64,7 @@ def test_mlp_recomputation(gated, activation_type):
Assert.all_equal(param.grad_buffer, param_grad_ref)
+@requires_cuda
def test_dropless_mlp():
num_experts = 4
experts_per_token = 4
diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py
index 5cb7d1bb7..3ad4605ab 100644
--- a/tests/test_triton_kernels.py
+++ b/tests/test_triton_kernels.py
@@ -24,8 +24,10 @@
from fast_llm.functional.triton.rotary import triton_rotary_
from fast_llm.functional.triton.sparse_copy import get_sparse_map
from fast_llm.utils import Assert, rms_diff
+from tests.common import requires_cuda
+@requires_cuda
def test_triton_fill():
assert TritonConfig.TRITON_ENABLED
x = torch.randn(425, 549, dtype=torch.bfloat16, device="cuda")
@@ -33,6 +35,7 @@ def test_triton_fill():
assert x.min().item() == x.max().item() == 32
+@requires_cuda
def test_triton_copy():
assert TritonConfig.TRITON_ENABLED
x = torch.randn(7563, dtype=torch.bfloat16, device="cuda")
@@ -44,6 +47,7 @@ def test_triton_copy():
Assert.all_equal(x, x1)
+@requires_cuda
def test_triton_copy_cast():
assert TritonConfig.TRITON_ENABLED
x = torch.randn(7563, dtype=torch.bfloat16, device="cuda")
@@ -55,6 +59,7 @@ def test_triton_copy_cast():
Assert.all_equal(x, x1)
+@requires_cuda
def test_triton_add():
assert TritonConfig.TRITON_ENABLED
x = torch.randn(8934, dtype=torch.float32, device="cuda")
@@ -69,6 +74,7 @@ def test_triton_add():
Assert.all_equal(y, y1)
+@requires_cuda
@pytest.mark.parametrize(
("batch_size", "sequence_length", "num_heads", "kv_channels"),
[(4, 1024, 8, 128), (1, 32, 1, 16), (2, 2048, 2, 192), (3, 519, 7, 134)],
@@ -90,6 +96,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels):
Assert.rms_close(y1, y2, 1e-3)
+@requires_cuda
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("zero_centered", [True, False])
def test_triton_normalization(has_bias, zero_centered):
@@ -139,6 +146,7 @@ def test_triton_normalization(has_bias, zero_centered):
Assert.rms_close(bias_grad0, bias.grad, 1e-3)
+@requires_cuda
@pytest.mark.parametrize("gated", [True, False])
@pytest.mark.parametrize(
"activation_type", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu]
@@ -161,6 +169,7 @@ def test_triton_mlp_activation(gated, activation_type, recompute):
Assert.rms_close(output1, output3, 1e-5)
+@requires_cuda
def test_triton_cross_entropy():
assert TritonConfig.TRITON_ENABLED
logits = torch.randn(1024, 8192, dtype=torch.bfloat16, device="cuda", requires_grad=True)
@@ -181,6 +190,7 @@ def test_triton_cross_entropy():
Assert.rms_close(g2, g3, 1e-3)
+@requires_cuda
def test_triton_adam():
assert TritonConfig.TRITON_ENABLED
params = torch.randn(4576427, dtype=torch.float32, device="cuda")
@@ -238,6 +248,7 @@ def compare(i, j, fn, arg):
compare(0, 4, Assert.eq, 0)
+@requires_cuda
@pytest.mark.parametrize(
("num_rows_dense", "num_experts", "num_experts_per_token"),
[(2048, 8, 2), (2048, 6, 2), (2048, 8, 8), (256, 8, 2), (5627, 8, 2)],