From 895ac3fe04a6df30d7ee33f3326b43d2f0599490 Mon Sep 17 00:00:00 2001 From: Rakshit-gen Date: Thu, 18 Dec 2025 21:24:13 +0530 Subject: [PATCH 1/4] Fix #7733: Replace torch.sqrt with math.sqrt in scale_lr for sqrt method When using lr_scaling_method='sqrt' with dynamic batching, the scale_lr function was failing with TypeError because torch.sqrt expects a Tensor but receives a Python float from batch_size/base_batch_size division. Changed torch.sqrt to math.sqrt which correctly handles Python floats. This fixes the issue where training would fail with: TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float Signed-off-by: Rakshit-gen --- .../data_pipeline/data_sampling/variable_batch_size_and_lr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py index c9a39bbc53b5..e61d0acb196c 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py @@ -8,6 +8,7 @@ import random import torch import os +import math import numpy as np from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer @@ -156,7 +157,7 @@ def scale_lr(base_batch_size, batch_size, base_lr=1, method="linear"): # Square Root scaling: "when multiplying the batch size by k, multiply the learning rate # by √k, to keep the variance in the gradient expectation constant" # (A. Krizhevsky. One weird trick for parallelizing convolutional neural networks) - return base_lr * torch.sqrt(batch_size / base_batch_size) + return base_lr * math.sqrt(batch_size / base_batch_size) elif method == None or method.upper() == "NONE": return base_lr raise ValueError("Unknown scaling method: {}".format(method)) From 7286dd94bc86cd766782fd29da7de2f25db8feab Mon Sep 17 00:00:00 2001 From: Rakshit-gen Date: Fri, 19 Dec 2025 20:00:07 +0530 Subject: [PATCH 2/4] removed unused imports Signed-off-by: Rakshit-gen --- .pre-commit-config.yaml | 2 +- .../data_pipeline/data_sampling/variable_batch_size_and_lr.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9a7bb1c9b371..e9a9c33cb773 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,7 +65,7 @@ repos: ] - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 7.0.0 hooks: - id: flake8 args: ['--config=.flake8'] diff --git a/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py index e61d0acb196c..20ecd64052ec 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py @@ -6,7 +6,6 @@ # support/questions/maintenance: github user @brunomaga or @deepspeedai/deepspeed import random -import torch import os import math import numpy as np From eea3ef8621b65caf28d765fb2dd22fcb5a651251 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Thu, 18 Dec 2025 12:21:53 -0600 Subject: [PATCH 3/4] Fix rare hang in DeepSpeed Async I/O wait by releasing the Python GIL (#7727) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _**What this PR does**_ - This PR fixes an occasional deadlock / hang when using DeepSpeed Async I/O (AIO) for NVMe swap-in/swap-out - The hang happens inside aio_handle.wait() where training can stall forever. _**Reproduction**_ [ds_config.json](https://github.com/user-attachments/files/24179010/ds_config.json) [finetune_zero3.py](https://github.com/user-attachments/files/24179011/finetune_zero3.py) Steps 1. Replace {NVME_PATH} in ds_config.json with a valid NVMe mount path on your cluster. 2. Build/install DeepSpeed with AIO enabled: `DS_BUILD_AIO=1 pip install --no-build-isolation .` 3. Run: `CUDA_VISIBLE_DEVICES=0 deepspeed finetune_zero3.py` _**Fix:**_ Release the Python GIL while aio_handle.wait() is blocking by adding a pybind11 call guard (py::gil_scoped_release) to the wait() binding. _**Why this is needed (root cause)**_ Two threads are involved: - Python main thread: calls aio_handle.wait() and blocks until all async I/O operations complete. - AIO worker thread(s): perform the actual file I/O in the background. In some cases, after an I/O operation completes, the worker thread triggers cleanup of PyTorch tensors (e.g., decref / refcount updates for Python-backed objects). That cleanup path may require acquiring the Python GIL. **Before this PR:** - The Python main thread enters aio_handle.wait() while still holding the GIL. - wait() blocks, waiting for the worker thread(s) to finish. - A worker thread completes an I/O op and reaches a cleanup path that attempts to acquire the GIL. - The worker thread cannot acquire the GIL because it is held by the Python thread blocked in wait(). - Result: the Python thread is waiting for the worker, and the worker is waiting for the GIL → deadlock. Signed-off-by: Rakshit-gen --- csrc/aio/py_lib/py_ds_aio.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index 62500bf4a6e9..cf9838cf8191 100644 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -124,5 +124,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) .def("wait", &deepspeed_aio_handle_t::wait, - "Wait for (ongoing) asynchronous operations to complete"); + "Wait for (ongoing) asynchronous operations to complete", + py::call_guard()); } From 994da4735f38c550556632fdb13ef5c791ab95e0 Mon Sep 17 00:00:00 2001 From: Rakshit-gen Date: Fri, 19 Dec 2025 20:23:51 +0530 Subject: [PATCH 4/4] revert unneccesary change Signed-off-by: Rakshit-gen --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9a9c33cb773..9a7bb1c9b371 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,7 +65,7 @@ repos: ] - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 5.0.4 hooks: - id: flake8 args: ['--config=.flake8']