From 9685f394b8bffb49de02cb987d00eb0115ddb4b4 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 17 Jan 2026 15:47:51 -0800 Subject: [PATCH] fix(pipeline): set _running_engine_backward for non-last stage backward In PipelineEngine._exec_backward_pass(), for non-last stages (Stage 0), torch.autograd.backward() was called directly without setting _running_engine_backward=True. This caused the post-backward hook (_backward_post_hook) to raise a RuntimeError when needs_scaler=True because it incorrectly detected that backward() was called without proper loss scaling. The exception raised inside the callback caused the process to hang, which in turn caused NCCL collective operations to deadlock while waiting for all ranks. Fix by setting _running_engine_backward=True before calling backward() for non-last stages, and resetting it in a finally block. Also update to use the new tensor.backward(gradient) API style instead of torch.autograd.backward(), which properly integrates with DeepSpeed's hooks and loss scaling for non-scalar backward. Fixes pipeline checkpoint tests timing out with ZeRO stage 1. Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/pipe/engine.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 3ca4a3b49f39..463ab711d3cb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -852,13 +852,22 @@ def _exec_backward_pass(self, buffer_id): # manually call because we don't call optimizer.backward() self.optimizer.clear_lp_grads() - # This handles either a single tensor or tuple of tensors. - if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if t.is_floating_point()] - assert len(out_tensors) == len(grad_tensors) - torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) - else: - torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + # Set _running_engine_backward to avoid RuntimeError in post-backward hook + # when needs_scaler=True (the hook checks this flag to skip error checking) + self._running_engine_backward = True + try: + # Use tensor.backward(gradient) style which is now supported by DeepSpeed. + # This properly integrates with DeepSpeed's hooks and loss scaling. + if isinstance(outputs, tuple): + out_tensors = [t for t in outputs if t.is_floating_point()] + assert len(out_tensors) == len(grad_tensors) + # For multiple tensors, use retain_graph for all but the last + for i, (out, grad) in enumerate(zip(out_tensors, grad_tensors)): + out.backward(gradient=grad, retain_graph=(i < len(out_tensors) - 1)) + else: + outputs.backward(gradient=grad_tensors) + finally: + self._running_engine_backward = False if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward()