
Commit 516082c (parent e676aa3)

add bf16 master weights example

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>

7 files changed: +1020, −0 lines
Lines changed: 217 additions & 0 deletions
# BF16 Low-Precision Master Weights and Optimizer States

This example demonstrates DeepSpeed's new low-precision training options, which can significantly reduce memory usage:

- `bf16_master_weights_and_grads`: keep master parameters and gradients in BF16 instead of FP32
- `bf16_optimizer_states`: keep optimizer states (e.g., Adam moments) in BF16

These options work with ZeRO Stage 3 and `torch.autocast` to provide memory-efficient training while maintaining numerical stability.
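For reference, the same two flags can be expressed as a Python config dict rather than a JSON file; a minimal sketch (the keys mirror this example's configs, and the surrounding training code is omitted):

```python
# The two low-precision flags in DeepSpeed config-dict form.
# Keys mirror this example's JSON configs; training code itself is omitted.
ds_config = {
    "bf16": {
        "enabled": True,
        "bf16_master_weights_and_grads": True,  # BF16 master params + grads
        "bf16_optimizer_states": True,          # BF16 Adam moments
    },
    "zero_optimization": {"stage": 3},          # these options require ZeRO Stage 3
    "torch_autocast": {"enabled": True, "dtype": "torch.bfloat16"},
}
```

DeepSpeed accepts either a config file path or a dict like this when the engine is initialized.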
9+
10+
## Memory Savings
11+
12+
Using a 254M parameter simple transformer model with the following configuration:
13+
- Hidden dimension: 1024
14+
- Layers: 12
15+
- Attention heads: 16
16+
- Batch size: 4
17+
- Sequence length: 512
18+
- ZeRO Stage: 3
19+
20+
### 1-GPU Results
21+
22+
| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
23+
|---------------|------------------|-------------|---------------|
24+
| Baseline (fp32 master) | 4.14 GB | 5.93 GB | 0.1042s |
25+
| BF16 low-precision (master + opt states) | **2.71 GB** | 5.73 GB | 0.1121s |
26+
27+
**Allocated memory reduction: 1.43 GB (34.5%)**
28+
29+
### 4-GPU Results (per GPU) - 254M Model
30+
31+
| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
32+
|---------------|------------------|-------------|---------------|
33+
| Baseline (fp32 master) | 1.29 GB | 3.57 GB | 0.1189s |
34+
| BF16 low-precision (master + opt states) | **0.94 GB** | 4.44 GB | 0.1249s |
35+
36+
**Allocated memory reduction: 0.35 GB per GPU (27%)**
37+
38+
### 4-GPU Results (per GPU) - 6.86B Model
39+
40+
Using a 6.86B parameter model (hidden=4096, layers=32, heads=32, batch=1, seq=512):
41+
42+
| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
43+
|---------------|------------------|-------------|---------------|
44+
| Baseline (fp32 master) | 25.74 GB | 41.28 GB | 0.5078s |
45+
| BF16 low-precision (master + opt states) | **16.17 GB** | **33.20 GB** | 0.5064s |
46+
47+
**Memory reduction: 9.57 GB allocated (37%), 8.08 GB peak (19.6%)**
48+
49+
### 4-GPU Results (per GPU) - 6.86B Model with Activation Checkpointing
50+
51+
With activation checkpointing enabled, the optimizer state memory becomes the dominant factor, making the savings even more visible:
52+
53+
| Configuration | Allocated Memory | Peak Memory | Avg Step Time |
54+
|---------------|------------------|-------------|---------------|
55+
| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s |
56+
| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s |
57+
58+
**Memory reduction: 9.57 GB allocated (37%), 12.45 GB peak (39.7%)**
59+
60+
With activation checkpointing, peak memory drops significantly for both configurations, but the bf16 low-precision option shows an even larger relative improvement - nearly **40% reduction in peak memory**.
61+
62+
The allocated memory reflects the optimizer state memory, which is where the low-precision options provide savings. Peak memory includes activations and temporary buffers which can vary based on execution order.
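The "Allocated Memory" and "Peak Memory" columns correspond to PyTorch's CUDA memory counters; a minimal sketch of how such numbers are typically collected (illustrative, not this example's exact instrumentation):

```python
def gb(num_bytes: int) -> float:
    """Convert bytes to GiB, rounded for reporting."""
    return round(num_bytes / 2**30, 2)

def report_memory(tag: str) -> tuple[float, float]:
    """Return (allocated, peak) GPU memory in GB; (0, 0) without a GPU."""
    try:
        import torch
    except ImportError:
        return 0.0, 0.0
    if not torch.cuda.is_available():
        return 0.0, 0.0
    allocated = gb(torch.cuda.memory_allocated())
    peak = gb(torch.cuda.max_memory_allocated())
    print(f"{tag}: allocated={allocated} GB, peak={peak} GB")
    return allocated, peak

# Typical pattern: call torch.cuda.reset_peak_memory_stats() after the warmup
# steps, then read the counters after the measured steps.
```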
## Loss Curve Comparison

To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset:

![Loss Comparison](logs/7b_loss_run/loss_comparison.png)

| Configuration | Final Loss | Mean Loss | Loss Std |
|---------------|------------|-----------|----------|
| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 |
| BF16 Low-Precision | 3.12 | 2.90 | 2.37 |

The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings.

To reproduce the loss curve comparison:

```bash
# Run 1000 steps with the Wikitext dataset
deepspeed --num_gpus=4 train.py --deepspeed_config configs/baseline.json \
    --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
    --num_steps 1000 --activation_checkpointing \
    --loss_log_file logs/baseline_loss.csv --use_real_data --seed 42

deepspeed --num_gpus=4 train.py --deepspeed_config configs/bf16_full.json \
    --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \
    --num_steps 1000 --activation_checkpointing \
    --loss_log_file logs/bf16_full_loss.csv --use_real_data --seed 42

# Generate the comparison plot
python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss.csv \
    --output loss_comparison.png
```
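The final/mean/std statistics in the table above can be recomputed from the logged CSVs; a small sketch, assuming a simple `step,loss` column layout (the actual columns written via `--loss_log_file` may differ):

```python
import csv
import statistics
from io import StringIO

def loss_stats(csv_text: str, column: str = "loss") -> dict:
    """Final, mean, and sample std of a loss column from CSV text."""
    rows = csv.DictReader(StringIO(csv_text))
    losses = [float(row[column]) for row in rows]
    return {
        "final": losses[-1],
        "mean": statistics.fmean(losses),
        "std": statistics.stdev(losses),
    }

sample = "step,loss\n1,10.0\n2,4.0\n3,3.0\n"
stats = loss_stats(sample)
print(stats)
```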
## Configuration

### Baseline (FP32 master weights and optimizer states)

```json
{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3
  }
}
```

### BF16 Low-Precision (BF16 master weights, gradients, and optimizer states)

```json
{
  "bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
  },
  "zero_optimization": {
    "stage": 3
  },
  "torch_autocast": {
    "enabled": true,
    "dtype": "torch.bfloat16"
  }
}
```
## Usage

### Run Individual Configurations

```bash
# Run the baseline configuration
deepspeed --num_gpus=1 train.py --deepspeed_config configs/baseline.json

# Run the BF16 low-precision configuration
deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_full.json
```

### Run Memory Comparison

```bash
# Run both configurations and generate a comparison report
./run_comparison.sh

# With custom settings
./run_comparison.sh --num_layers 24 --hidden_dim 2048 --batch_size 2
```

### Gather Results from Logs

```bash
python gather_memory.py --log_dir logs/<timestamp>
```

## Training Script Options

```
--hidden_dim          Hidden dimension size (default: 1024)
--num_layers          Number of transformer layers (default: 12)
--num_heads           Number of attention heads (default: 16)
--vocab_size          Vocabulary size (default: 50000)
--batch_size          Batch size per GPU (default: 4)
--seq_length          Sequence length (default: 512)
--num_steps           Number of training steps (default: 20)
--warmup_steps        Warmup steps before measuring (default: 5)
--deepspeed_config    Path to DeepSpeed config file
```
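These options map directly onto a standard `argparse` parser; a sketch of the relevant subset (train.py's real parser may differ and also carries DeepSpeed launcher arguments such as `--local_rank`):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Defaults mirror the option table above."""
    p = argparse.ArgumentParser(description="BF16 low-precision master weights example")
    p.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension size")
    p.add_argument("--num_layers", type=int, default=12, help="Number of transformer layers")
    p.add_argument("--num_heads", type=int, default=16, help="Number of attention heads")
    p.add_argument("--vocab_size", type=int, default=50000, help="Vocabulary size")
    p.add_argument("--batch_size", type=int, default=4, help="Batch size per GPU")
    p.add_argument("--seq_length", type=int, default=512, help="Sequence length")
    p.add_argument("--num_steps", type=int, default=20, help="Number of training steps")
    p.add_argument("--warmup_steps", type=int, default=5, help="Warmup steps before measuring")
    p.add_argument("--deepspeed_config", type=str, help="Path to DeepSpeed config file")
    return p

args = build_parser().parse_args(["--hidden_dim", "4096", "--num_layers", "32"])
```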
## Requirements

- DeepSpeed with BF16 support
- PyTorch with BF16 support
- A GPU with BF16 support (e.g., NVIDIA Ampere or newer)

## How It Works

### Standard BF16 Training (Baseline)

In standard BF16 training with DeepSpeed:

- Model parameters are stored in BF16
- Forward/backward computations use BF16 via `torch.autocast`
- Master weights are maintained in FP32 for optimizer updates
- Optimizer states (Adam momentum and variance) are kept in FP32

This requires significant memory for the FP32 copies.
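The baseline flow can be sketched with plain PyTorch (CPU autocast is used here so the snippet runs anywhere; DeepSpeed's engine wraps the same pattern, and the tiny linear model is a stand-in for the transformer):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in model; the example uses a transformer
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(4, 8)

# Forward runs in BF16 under autocast; the parameters themselves stay FP32,
# mirroring the "FP32 master weights" role in the baseline.
with torch.autocast("cpu", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()

loss.backward()   # gradients land in the parameters' dtype (FP32 here)
opt.step()        # AdamW allocates FP32 exp_avg / exp_avg_sq state buffers
opt.zero_grad()
```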
### BF16 Low-Precision Training

With the new options enabled:

- `bf16_master_weights_and_grads=true`: master weights and gradients stay in BF16
- `bf16_optimizer_states=true`: Adam momentum and variance buffers use BF16

This eliminates the FP32 copies, reducing memory by approximately 2 bytes per parameter each for master weights and gradients, plus 4 bytes per parameter for optimizer states (Adam keeps two state buffers).
### Memory Breakdown

For a model with N parameters:

| Component | Baseline | BF16 Low-Precision |
|-----------|----------|--------------------|
| Model params | 2N bytes (BF16) | 2N bytes (BF16) |
| Master weights | 4N bytes (FP32) | 2N bytes (BF16) |
| Gradients | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam momentum | 4N bytes (FP32) | 2N bytes (BF16) |
| Adam variance | 4N bytes (FP32) | 2N bytes (BF16) |
| **Total** | **18N bytes** | **10N bytes** |

This gives a theoretical ~44% reduction in optimizer-related memory. The actual savings depend on activation memory and other factors.
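The 18N-vs-10N accounting is easy to sanity-check numerically; a sketch for the 6.86B-parameter model from the results above (ignoring activations, temporary buffers, and ZeRO-3 sharding across GPUs):

```python
def training_state_bytes(num_params: int, low_precision: bool) -> int:
    """Sum the five rows of the breakdown table for one full (unsharded) copy."""
    BF16, FP32 = 2, 4
    state = BF16 if low_precision else FP32
    # model params (always BF16) + master weights + gradients + Adam momentum + variance
    return num_params * (BF16 + 4 * state)

N = 6_860_000_000
baseline = training_state_bytes(N, low_precision=False)  # 18N bytes
low_prec = training_state_bytes(N, low_precision=True)   # 10N bytes
print(f"baseline:      {baseline / 2**30:.1f} GiB")
print(f"low-precision: {low_prec / 2**30:.1f} GiB")
print(f"reduction:     {1 - low_prec / baseline:.1%}")
```

Sharded across 4 GPUs under ZeRO-3, 10N works out to roughly 16 GiB per GPU, in line with the 16.17 GB allocated memory measured above.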
## Related Resources

- [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options)
- [DeepSpeed Core API Updates Blog](../../blogs/core_api_update/README.md)
- [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700)
Lines changed: 31 additions & 0 deletions
```json
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 100,

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },

  "bf16": {
    "enabled": true
  },

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 0
  },

  "torch_autocast": {
    "enabled": false
  }
}
```
Lines changed: 34 additions & 0 deletions
```json
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 100,

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },

  "bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
  },

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 0
  },

  "torch_autocast": {
    "enabled": true,
    "dtype": "torch.bfloat16"
  }
}
```
