diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 316e79da4fd6..57b28991d255 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -165,9 +165,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
@@ -269,7 +268,6 @@ def __init__(self, time_dim, model_dim, num_params):
         self.out_layer.weight.data.zero_()
         self.out_layer.bias.data.zero_()
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
         return self.out_layer(self.activation(x))
 
@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
         "Kandinsky5TransformerEncoderBlock",
         "Kandinsky5TransformerDecoderBlock",
     ]
+    _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
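
For context, a minimal standalone sketch of the pattern the patch switches to: the removed `@torch.autocast(device_type="cuda", dtype=torch.float32)` decorator only takes effect on CUDA, whereas an explicit `time.to(torch.float32)` inside `forward` keeps the sinusoidal arguments in full precision on any device. The class name, dimensions, and frequency formula below are illustrative assumptions for demonstration, not the library's actual implementation:

```python
import math

import torch
from torch import nn


class SinusoidalTimeEmbedding(nn.Module):
    """Illustrative fp32-safe timestep embedding mirroring the cast in the patch."""

    def __init__(self, model_dim: int, time_dim: int, max_period: float = 10000.0):
        super().__init__()
        half_dim = model_dim // 2
        # Precompute sinusoid frequencies in float32 (assumed formula, for illustration).
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        # Explicit fp32 cast instead of a CUDA-only autocast decorator, so the
        # cos/sin arguments keep full precision on CPU, MPS, etc.
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.out_layer(self.activation(self.in_layer(time_embed)))


# Usage: the embedding stays in float32 even for low-precision timestep inputs.
emb = SinusoidalTimeEmbedding(model_dim=64, time_dim=128)
print(emb(torch.arange(4, dtype=torch.bfloat16)).shape)  # torch.Size([4, 128])
```

The added `_keep_in_fp32_modules` entry complements this: diffusers uses that class attribute to keep the listed submodules in float32 when the model is loaded with a lower-precision `torch_dtype`, so the time-embedding and modulation layers retain full precision without relying on autocast.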