Skip to content

Commit 1089219

Browse files
author
Yuwei Guo
committed
support v2
1 parent 1b50d64 commit 1089219

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

animatediff/models/resnet.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ def forward(self, x):
1818
return x
1919

2020

21+
class InflatedGroupNorm(nn.GroupNorm):
22+
def forward(self, x):
23+
video_length = x.shape[2]
24+
25+
x = rearrange(x, "b c f h w -> (b f) c h w")
26+
x = super().forward(x)
27+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28+
29+
return x
30+
31+
2132
class Upsample3D(nn.Module):
2233
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
2334
super().__init__()
@@ -112,6 +123,7 @@ def __init__(
112123
time_embedding_norm="default",
113124
output_scale_factor=1.0,
114125
use_in_shortcut=None,
126+
use_inflated_groupnorm=None,
115127
):
116128
super().__init__()
117129
self.pre_norm = pre_norm
@@ -126,7 +138,11 @@ def __init__(
126138
if groups_out is None:
127139
groups_out = groups
128140

129-
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
141+
assert use_inflated_groupnorm != None
142+
if use_inflated_groupnorm:
143+
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144+
else:
145+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
130146

131147
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
132148

@@ -142,7 +158,11 @@ def __init__(
142158
else:
143159
self.time_emb_proj = None
144160

145-
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161+
if use_inflated_groupnorm:
162+
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
163+
else:
164+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
165+
146166
self.dropout = torch.nn.Dropout(dropout)
147167
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
148168

animatediff/models/unet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
get_down_block,
2525
get_up_block,
2626
)
27-
from .resnet import InflatedConv3d
27+
from .resnet import InflatedConv3d, InflatedGroupNorm
2828

2929

3030
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -77,6 +77,8 @@ def __init__(
7777
upcast_attention: bool = False,
7878
resnet_time_scale_shift: str = "default",
7979

80+
use_inflated_groupnorm=False,
81+
8082
# Additional
8183
use_motion_module = False,
8284
motion_module_resolutions = ( 1,2,4,8 ),
@@ -88,7 +90,7 @@ def __init__(
8890
unet_use_temporal_attention = None,
8991
):
9092
super().__init__()
91-
93+
9294
self.sample_size = sample_size
9395
time_embed_dim = block_out_channels[0] * 4
9496

@@ -150,6 +152,7 @@ def __init__(
150152

151153
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
152154
unet_use_temporal_attention=unet_use_temporal_attention,
155+
use_inflated_groupnorm=use_inflated_groupnorm,
153156

154157
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
155158
motion_module_type=motion_module_type,
@@ -175,6 +178,7 @@ def __init__(
175178

176179
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
177180
unet_use_temporal_attention=unet_use_temporal_attention,
181+
use_inflated_groupnorm=use_inflated_groupnorm,
178182

179183
use_motion_module=use_motion_module and motion_module_mid_block,
180184
motion_module_type=motion_module_type,
@@ -227,6 +231,7 @@ def __init__(
227231

228232
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
229233
unet_use_temporal_attention=unet_use_temporal_attention,
234+
use_inflated_groupnorm=use_inflated_groupnorm,
230235

231236
use_motion_module=use_motion_module and (res in motion_module_resolutions),
232237
motion_module_type=motion_module_type,
@@ -236,7 +241,10 @@ def __init__(
236241
prev_output_channel = output_channel
237242

238243
# out
239-
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
244+
if use_inflated_groupnorm:
245+
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
246+
else:
247+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
240248
self.conv_act = nn.SiLU()
241249
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
242250

animatediff/models/unet_blocks.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def get_down_block(
3030

3131
unet_use_cross_frame_attention=None,
3232
unet_use_temporal_attention=None,
33-
33+
use_inflated_groupnorm=None,
34+
3435
use_motion_module=None,
3536

3637
motion_module_type=None,
@@ -50,6 +51,8 @@ def get_down_block(
5051
downsample_padding=downsample_padding,
5152
resnet_time_scale_shift=resnet_time_scale_shift,
5253

54+
use_inflated_groupnorm=use_inflated_groupnorm,
55+
5356
use_motion_module=use_motion_module,
5457
motion_module_type=motion_module_type,
5558
motion_module_kwargs=motion_module_kwargs,
@@ -77,6 +80,7 @@ def get_down_block(
7780

7881
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
7982
unet_use_temporal_attention=unet_use_temporal_attention,
83+
use_inflated_groupnorm=use_inflated_groupnorm,
8084

8185
use_motion_module=use_motion_module,
8286
motion_module_type=motion_module_type,
@@ -106,6 +110,7 @@ def get_up_block(
106110

107111
unet_use_cross_frame_attention=None,
108112
unet_use_temporal_attention=None,
113+
use_inflated_groupnorm=None,
109114

110115
use_motion_module=None,
111116
motion_module_type=None,
@@ -125,6 +130,8 @@ def get_up_block(
125130
resnet_groups=resnet_groups,
126131
resnet_time_scale_shift=resnet_time_scale_shift,
127132

133+
use_inflated_groupnorm=use_inflated_groupnorm,
134+
128135
use_motion_module=use_motion_module,
129136
motion_module_type=motion_module_type,
130137
motion_module_kwargs=motion_module_kwargs,
@@ -152,6 +159,7 @@ def get_up_block(
152159

153160
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154161
unet_use_temporal_attention=unet_use_temporal_attention,
162+
use_inflated_groupnorm=use_inflated_groupnorm,
155163

156164
use_motion_module=use_motion_module,
157165
motion_module_type=motion_module_type,
@@ -181,6 +189,7 @@ def __init__(
181189

182190
unet_use_cross_frame_attention=None,
183191
unet_use_temporal_attention=None,
192+
use_inflated_groupnorm=None,
184193

185194
use_motion_module=None,
186195

@@ -206,6 +215,8 @@ def __init__(
206215
non_linearity=resnet_act_fn,
207216
output_scale_factor=output_scale_factor,
208217
pre_norm=resnet_pre_norm,
218+
219+
use_inflated_groupnorm=use_inflated_groupnorm,
209220
)
210221
]
211222
attentions = []
@@ -248,6 +259,8 @@ def __init__(
248259
non_linearity=resnet_act_fn,
249260
output_scale_factor=output_scale_factor,
250261
pre_norm=resnet_pre_norm,
262+
263+
use_inflated_groupnorm=use_inflated_groupnorm,
251264
)
252265
)
253266

@@ -290,6 +303,7 @@ def __init__(
290303

291304
unet_use_cross_frame_attention=None,
292305
unet_use_temporal_attention=None,
306+
use_inflated_groupnorm=None,
293307

294308
use_motion_module=None,
295309

@@ -318,6 +332,8 @@ def __init__(
318332
non_linearity=resnet_act_fn,
319333
output_scale_factor=output_scale_factor,
320334
pre_norm=resnet_pre_norm,
335+
336+
use_inflated_groupnorm=use_inflated_groupnorm,
321337
)
322338
)
323339
if dual_cross_attention:
@@ -421,6 +437,8 @@ def __init__(
421437
output_scale_factor=1.0,
422438
add_downsample=True,
423439
downsample_padding=1,
440+
441+
use_inflated_groupnorm=None,
424442

425443
use_motion_module=None,
426444
motion_module_type=None,
@@ -444,6 +462,8 @@ def __init__(
444462
non_linearity=resnet_act_fn,
445463
output_scale_factor=output_scale_factor,
446464
pre_norm=resnet_pre_norm,
465+
466+
use_inflated_groupnorm=use_inflated_groupnorm,
447467
)
448468
)
449469
motion_modules.append(
@@ -526,6 +546,7 @@ def __init__(
526546

527547
unet_use_cross_frame_attention=None,
528548
unet_use_temporal_attention=None,
549+
use_inflated_groupnorm=None,
529550

530551
use_motion_module=None,
531552

@@ -556,6 +577,8 @@ def __init__(
556577
non_linearity=resnet_act_fn,
557578
output_scale_factor=output_scale_factor,
558579
pre_norm=resnet_pre_norm,
580+
581+
use_inflated_groupnorm=use_inflated_groupnorm,
559582
)
560583
)
561584
if dual_cross_attention:
@@ -661,6 +684,8 @@ def __init__(
661684
output_scale_factor=1.0,
662685
add_upsample=True,
663686

687+
use_inflated_groupnorm=None,
688+
664689
use_motion_module=None,
665690
motion_module_type=None,
666691
motion_module_kwargs=None,
@@ -685,6 +710,8 @@ def __init__(
685710
non_linearity=resnet_act_fn,
686711
output_scale_factor=output_scale_factor,
687712
pre_norm=resnet_pre_norm,
713+
714+
use_inflated_groupnorm=use_inflated_groupnorm,
688715
)
689716
)
690717
motion_modules.append(

0 commit comments

Comments
 (0)