Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# support/questions/maintenance: github user @brunomaga or @deepspeedai/deepspeed

import random
import torch
import os
import math
import numpy as np
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -156,7 +156,7 @@ def scale_lr(base_batch_size, batch_size, base_lr=1, method="linear"):
# Square Root scaling: "when multiplying the batch size by k, multiply the learning rate
# by √k, to keep the variance in the gradient expectation constant"
# (A. Krizhevsky. One weird trick for parallelizing convolutional neural networks)
return base_lr * torch.sqrt(batch_size / base_batch_size)
return base_lr * math.sqrt(batch_size / base_batch_size)
elif method == None or method.upper() == "NONE":
return base_lr
raise ValueError("Unknown scaling method: {}".format(method))
Expand Down