diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 3ca4a3b49f39..463ab711d3cb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -852,13 +852,22 @@ def _exec_backward_pass(self, buffer_id): # manually call because we don't call optimizer.backward() self.optimizer.clear_lp_grads() - # This handles either a single tensor or tuple of tensors. - if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if t.is_floating_point()] - assert len(out_tensors) == len(grad_tensors) - torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors) - else: - torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, )) + # Set _running_engine_backward to avoid RuntimeError in post-backward hook + # when needs_scaler=True (the hook checks this flag to skip error checking) + self._running_engine_backward = True + try: + # Use tensor.backward(gradient) style which is now supported by DeepSpeed. + # This properly integrates with DeepSpeed's hooks and loss scaling. + if isinstance(outputs, tuple): + out_tensors = [t for t in outputs if t.is_floating_point()] + assert len(out_tensors) == len(grad_tensors) + # For multiple tensors, use retain_graph for all but the last + for i, (out, grad) in enumerate(zip(out_tensors, grad_tensors)): + out.backward(gradient=grad, retain_graph=(i < len(out_tensors) - 1)) + else: + outputs.backward(gradient=grad_tensors) + finally: + self._running_engine_backward = False if self.using_bf16_optimizer and not self.is_last_stage(): # manually call because we don't call optimizer.backward()