Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,13 +852,22 @@ def _exec_backward_pass(self, buffer_id):
# manually call because we don't call optimizer.backward()
self.optimizer.clear_lp_grads()

# This handles either a single tensor or tuple of tensors.
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if t.is_floating_point()]
assert len(out_tensors) == len(grad_tensors)
torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
else:
torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
# Set _running_engine_backward to avoid RuntimeError in post-backward hook
# when needs_scaler=True (the hook checks this flag to skip error checking)
self._running_engine_backward = True
try:
Comment thread
sfc-gh-truwase marked this conversation as resolved.
# Use tensor.backward(gradient) style which is now supported by DeepSpeed.
# This properly integrates with DeepSpeed's hooks and loss scaling.
if isinstance(outputs, tuple):
out_tensors = [t for t in outputs if t.is_floating_point()]
assert len(out_tensors) == len(grad_tensors)
# For multiple tensors, use retain_graph for all but the last
for i, (out, grad) in enumerate(zip(out_tensors, grad_tensors)):
out.backward(gradient=grad, retain_graph=(i < len(out_tensors) - 1))
else:
outputs.backward(gradient=grad_tensors)
finally:
self._running_engine_backward = False

if self.using_bf16_optimizer and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
Expand Down