diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 462bb679751d..9d92d5d7647d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1002,7 +1002,8 @@ def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=Fal def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc = self.get_gradient_for_reduction(param) - bucket = self.ipg_buckets[self.get_param_comm_dtype(param)] + comm_dtype = self.get_param_comm_dtype(param) + bucket = self.ipg_buckets[comm_dtype] if bucket.elements + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) self.reduce_ipg_grads() @@ -1022,7 +1023,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): if self.contiguous_gradients: if param.numel() > self.reduce_bucket_size: - self.extra_large_param_to_reduce[param.dtype] = param + self.extra_large_param_to_reduce[comm_dtype] = param else: # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening new_grad_tensor = bucket.buffer[bucket.index].narrow(0, bucket.elements, param.numel())