🐞 Describe the Bug
I noticed checkpoint saving is suspiciously slow in some tests, so I decided to investigate.
Checkpoint saving should be bottlenecked by hardware (disk write speed), but it turns out it isn't. Something in torch.save / safetensors seems to make serialization slow.
On the H100 nodes' local NVMe we should get more than 2 GiB/s of write speed, but torch.save and safetensors.torch.save_file only reach about a third of that or less (roughly 440–770 MiB/s in the benchmark).
I think the impact isn't too big for distributed checkpoints because multiple processes are saving at the same time, but it definitely makes exports and conversions slower than they should be.
This problem is hard to solve because it's in other libraries, but maybe we can find some kind of hack to speed things up, like using multiple serialization processes.
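A minimal sketch of the "multiple serialization processes" idea, assuming the state dict has already been split into shards. The `save_shards_in_parallel` helper and the `shards` mapping are hypothetical, not part of Fast-LLM: each shard is handed to its own worker, and `save_fn(obj, path)` does the actual serialization (for example `safetensors.torch.save_file`). Threads only help if the save function releases the GIL while it works, which has not been checked here for torch.save or safetensors; `use_processes=True` sidesteps the GIL but pickles every shard to its worker, which adds a copy.

```python
import concurrent.futures
from typing import Any, Callable, Mapping, Optional


def save_shards_in_parallel(
    shards: Mapping[str, Any],
    save_fn: Callable[[Any, str], None],
    max_workers: Optional[int] = None,
    use_processes: bool = False,
) -> list:
    """Save each shard to its own file concurrently and return the output paths.

    `shards` maps an output path to the object to save there (hypothetical layout).
    With use_processes=True, `save_fn` and the shards must be picklable.
    """
    executor_cls = (
        concurrent.futures.ProcessPoolExecutor
        if use_processes
        else concurrent.futures.ThreadPoolExecutor
    )
    with executor_cls(max_workers=max_workers) as executor:
        futures = [executor.submit(save_fn, obj, path) for path, obj in shards.items()]
        # Re-raise the first failure instead of silently dropping it.
        for future in concurrent.futures.as_completed(futures):
            future.result()
    return list(shards)
```

For example, `save_shards_in_parallel({"shard_0.safetensors": part_0, "shard_1.safetensors": part_1}, safetensors.torch.save_file)`, where `part_0` and `part_1` are hypothetical dicts of CPU tensors. Whether this actually beats a single writer depends on where the time goes, which the benchmarks below are meant to show.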
🔄 Steps to Reproduce
Third-party library benchmarks
Running this code:
```python
import contextlib
import pathlib
import shutil
import time

import numpy as np
import safetensors.torch
import torch

size = 2**30  # 1 GiB of uint8
a = torch.ones(size, dtype=torch.uint8, device="cuda")  # GPU tensor
b = a.cpu()  # CPU copy
c = b.numpy()  # numpy view of the CPU copy

out_dir = pathlib.Path("tmp")
# out_dir = pathlib.Path("/mnt/workspace/tmp/ckpt")
if out_dir.is_dir():
    shutil.rmtree(out_dir)
out_dir.mkdir(exist_ok=True)


@contextlib.contextmanager
def measure_time(name):
    # Wall-clock time and throughput for writing `size` bytes.
    start = time.perf_counter()
    yield
    stop = time.perf_counter()
    print(f"{name}: {stop - start:.3f}s, {size / 2**20 / (stop - start):.3f} MiB/s")


with measure_time("torch save from gpu"):
    torch.save(a, out_dir / "torch.pt")
with measure_time("torch save from cpu"):
    torch.save(b, out_dir / "torch1.pt")
with measure_time("numpy save"):
    np.save(out_dir / "np.npy", c)
with measure_time("safetensors save from gpu"):
    safetensors.torch.save_file({"a": a}, out_dir / "c.safetensors")
with measure_time("safetensors save from cpu"):
    safetensors.torch.save_file({"a": b}, out_dir / "c1.safetensors")
with measure_time("safetensors serialize in-memory from cpu"):
    d = safetensors.torch.save({"a": b})
with measure_time("safetensors serialize in-memory from gpu"):
    d = safetensors.torch.save({"a": a})
with measure_time("Write to disk"):
    # Raw write of the already-serialized bytes, no serialization involved.
    with (out_dir / "d.safetensors").open("wb") as f:
        f.write(d)
```
We get:
```
torch save from gpu: 1.338s, 765.159 MiB/s
torch save from cpu: 1.641s, 623.839 MiB/s
numpy save: 0.536s, 1909.272 MiB/s
safetensors save from gpu: 1.558s, 657.230 MiB/s
safetensors save from cpu: 2.307s, 443.877 MiB/s
safetensors serialize in-memory from cpu: 1.658s, 617.791 MiB/s
safetensors serialize in-memory from gpu: 2.092s, 489.483 MiB/s
Write to disk: 0.497s, 2062.189 MiB/s
```
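Dividing each throughput by the raw "Write to disk" figure makes the gap explicit. Every torch.save and safetensors path lands between roughly 21% and 37% of the raw write speed, while numpy reaches about 93%. Even the in-memory safetensors serialization, which does no disk I/O at all, is slower than the raw write, which points at serialization rather than the disk. The numbers below are copied from the output above; nothing else is assumed.

```python
# Throughputs (MiB/s) copied from the benchmark output above.
results_mib_s = {
    "torch save from gpu": 765.159,
    "torch save from cpu": 623.839,
    "numpy save": 1909.272,
    "safetensors save from gpu": 657.230,
    "safetensors save from cpu": 443.877,
    "safetensors serialize in-memory from cpu": 617.791,
    "safetensors serialize in-memory from gpu": 489.483,
}
raw_write_mib_s = 2062.189  # "Write to disk" baseline

# Fraction of the raw write speed reached by each method.
fractions = {name: v / raw_write_mib_s for name, v in results_mib_s.items()}
for name, frac in sorted(fractions.items(), key=lambda kv: kv[1]):
    print(f"{name:42s} {frac:6.1%} of raw write speed")
```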
I also tried remote storage (the commented-out path in the script). It is a bit slower (~1.5 GiB/s) but follows the same pattern.
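None of the timings above call fsync, so part of each file may still sit in the OS page cache when the timer stops. That applies to every measurement alike, so it mainly matters when estimating what the NVMe itself can sustain. A minimal sketch of a raw-write baseline that can optionally fsync before stopping the timer; `measure_write_throughput` is a hypothetical helper, and it writes random bytes (generated before the timer starts) so that a compressing filesystem cannot inflate the result.

```python
import os
import time


def measure_write_throughput(path, size_bytes, chunk_bytes=64 * 2**20, fsync=True):
    """Write `size_bytes` random bytes to `path` and return the throughput in MiB/s.

    With fsync=True the timed region also covers flushing the data to the device,
    so the result reflects the disk rather than the OS page cache.
    """
    # Build the payload before starting the timer.
    chunk = memoryview(os.urandom(min(chunk_bytes, size_bytes)))
    start = time.perf_counter()
    with open(path, "wb") as f:
        remaining = size_bytes
        while remaining > 0:
            n = min(len(chunk), remaining)
            f.write(chunk[:n])
            remaining -= n
        if fsync:
            f.flush()  # push Python's buffer to the OS...
            os.fsync(f.fileno())  # ...then ask the OS to persist it to the device
    elapsed = time.perf_counter() - start
    return size_bytes / 2**20 / elapsed
```

For example, `measure_write_throughput("tmp/probe.bin", 2**30)` writes the same 1 GiB as the benchmark; comparing `fsync=True` with `fsync=False` gives a rough idea of how much of a measured speed comes from the page cache.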
Fast-LLM benchmarks
Running the Mistral-7B 4-node benchmark on H100 nodes, with a checkpoint and an export added:
```
2024-10-25 16:14:10,223 [Rank 00] Saving checkpoint at iteration 100
2024-10-25 16:14:17,987 [Rank 00] Saved checkpoint to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/checkpoints/100
2024-10-25 16:14:17,989 [Rank 00] Saving export at iteration 100
2024-10-25 16:14:17,999 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_0.safetensors
2024-10-25 16:14:31,599 [Rank 00] Saving tensors to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict_1.safetensors
2024-10-25 16:14:40,670 [Rank 00] Saving index to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100/state_dict.safetensors.index.json
2024-10-25 16:14:40,679 [Rank 00] Saved export to /mnt/checkpoints/fast_llm_dev/benchmark_v1/mistral_4_nodes_2024_10_25_12_09_56/export/100
```
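The durations used below can be recovered from the log timestamps. This sketch assumes each "Saving tensors" line marks the start of that shard, so a shard is taken to end when the next line is logged; `elapsed_seconds` is a hypothetical helper.

```python
from datetime import datetime

# Timestamp format of the log lines above, e.g. "2024-10-25 16:14:10,223".
LOG_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S,%f"


def elapsed_seconds(start: str, stop: str) -> float:
    """Seconds between two log timestamps."""
    t0 = datetime.strptime(start, LOG_TIMESTAMP_FORMAT)
    t1 = datetime.strptime(stop, LOG_TIMESTAMP_FORMAT)
    return (t1 - t0).total_seconds()


# Checkpoint: "Saving checkpoint" -> "Saved checkpoint".
checkpoint_s = elapsed_seconds("2024-10-25 16:14:10,223", "2024-10-25 16:14:17,987")
# Export file 0: its "Saving tensors" line -> file 1's "Saving tensors" line.
export_0_s = elapsed_seconds("2024-10-25 16:14:17,999", "2024-10-25 16:14:31,599")
# Export file 1: its "Saving tensors" line -> "Saving index" line.
export_1_s = elapsed_seconds("2024-10-25 16:14:31,599", "2024-10-25 16:14:40,670")

print(f"checkpoint {checkpoint_s:.3f} s, export file 0 {export_0_s:.3f} s, file 1 {export_1_s:.3f} s")
```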
The checkpoint is saved in 7.76 s and each shard is 2.53 GiB, so the write speed is 334 MiB/s per process, or 2670 MiB/s per node (8 processes per node). This seems reasonable, but I don't know what speed we should be able to get in theory.
The first export file is saved in 13.60 s and is 8.01 GiB, so the write speed is 603 MiB/s (the second file reaches 618 MiB/s). This is at least 4x too slow.
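The throughput figures are size divided by time. A quick check of that arithmetic, assuming one saving process per GPU and therefore 8 per DGX-H100 node (`throughput_mib_s` is a hypothetical helper). For reference it also prints how the per-node checkpoint rate compares with the single-process export rate.

```python
GIB = 2**30
MIB = 2**20


def throughput_mib_s(size_gib: float, seconds: float) -> float:
    """Convert a file size in GiB and a duration in seconds to MiB/s."""
    return size_gib * GIB / MIB / seconds


checkpoint_per_process = throughput_mib_s(2.53, 7.76)  # one 2.53 GiB shard per process
checkpoint_per_node = checkpoint_per_process * 8  # assumes 8 processes (GPUs) per node
export_first_file = throughput_mib_s(8.01, 13.60)  # one process writing 8.01 GiB

print(f"checkpoint: {checkpoint_per_process:.0f} MiB/s/process, {checkpoint_per_node:.0f} MiB/s/node")
print(f"export:     {export_first_file:.0f} MiB/s")
print(f"node checkpoint rate / export rate: {checkpoint_per_node / export_first_file:.1f}x")
```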
📜 Environment Information
DGX-H100, saving locally (NVMe).