diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index f8c146c81d2f..dc3974e1680e 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1264,6 +1264,7 @@ def make_hook(params): def reduce_leaf_module_grads(module, grad_input, grad_output): for param in params: + # this takes care of grads for MoE experts that didn't participate in the current iteration/layer if param.grad is None: param.grad = torch.zeros_like(param) self.reduce_ready_partitions_and_remove_grads(param)