Skip to content

Commit fd011c1

Browse files
torch rmsnorm
1 parent 9b288b8 commit fd011c1

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

inference/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
140140
class RMSNorm(nn.Module):
141141
def __init__(self, dim: int, eps: float = 1e-6):
142142
super().__init__()
143+
self.dim = dim
143144
self.eps = eps
144145
self.weight = nn.Parameter(torch.ones(dim))
145146

146147
def forward(self, x: torch.Tensor):
147-
x = x.float()
148-
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
149-
return y.type_as(self.weight) * self.weight
148+
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
150149

151150

152151
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:

0 commit comments

Comments
 (0)