@@ -30,7 +30,8 @@ def get_down_block(
3030
3131 unet_use_cross_frame_attention = None ,
3232 unet_use_temporal_attention = None ,
33-
33+ use_inflated_groupnorm = None ,
34+
3435 use_motion_module = None ,
3536
3637 motion_module_type = None ,
@@ -50,6 +51,8 @@ def get_down_block(
5051 downsample_padding = downsample_padding ,
5152 resnet_time_scale_shift = resnet_time_scale_shift ,
5253
54+ use_inflated_groupnorm = use_inflated_groupnorm ,
55+
5356 use_motion_module = use_motion_module ,
5457 motion_module_type = motion_module_type ,
5558 motion_module_kwargs = motion_module_kwargs ,
@@ -77,6 +80,7 @@ def get_down_block(
7780
7881 unet_use_cross_frame_attention = unet_use_cross_frame_attention ,
7982 unet_use_temporal_attention = unet_use_temporal_attention ,
83+ use_inflated_groupnorm = use_inflated_groupnorm ,
8084
8185 use_motion_module = use_motion_module ,
8286 motion_module_type = motion_module_type ,
@@ -106,6 +110,7 @@ def get_up_block(
106110
107111 unet_use_cross_frame_attention = None ,
108112 unet_use_temporal_attention = None ,
113+ use_inflated_groupnorm = None ,
109114
110115 use_motion_module = None ,
111116 motion_module_type = None ,
@@ -125,6 +130,8 @@ def get_up_block(
125130 resnet_groups = resnet_groups ,
126131 resnet_time_scale_shift = resnet_time_scale_shift ,
127132
133+ use_inflated_groupnorm = use_inflated_groupnorm ,
134+
128135 use_motion_module = use_motion_module ,
129136 motion_module_type = motion_module_type ,
130137 motion_module_kwargs = motion_module_kwargs ,
@@ -152,6 +159,7 @@ def get_up_block(
152159
153160 unet_use_cross_frame_attention = unet_use_cross_frame_attention ,
154161 unet_use_temporal_attention = unet_use_temporal_attention ,
162+ use_inflated_groupnorm = use_inflated_groupnorm ,
155163
156164 use_motion_module = use_motion_module ,
157165 motion_module_type = motion_module_type ,
@@ -181,6 +189,7 @@ def __init__(
181189
182190 unet_use_cross_frame_attention = None ,
183191 unet_use_temporal_attention = None ,
192+ use_inflated_groupnorm = None ,
184193
185194 use_motion_module = None ,
186195
@@ -206,6 +215,8 @@ def __init__(
206215 non_linearity = resnet_act_fn ,
207216 output_scale_factor = output_scale_factor ,
208217 pre_norm = resnet_pre_norm ,
218+
219+ use_inflated_groupnorm = use_inflated_groupnorm ,
209220 )
210221 ]
211222 attentions = []
@@ -248,6 +259,8 @@ def __init__(
248259 non_linearity = resnet_act_fn ,
249260 output_scale_factor = output_scale_factor ,
250261 pre_norm = resnet_pre_norm ,
262+
263+ use_inflated_groupnorm = use_inflated_groupnorm ,
251264 )
252265 )
253266
@@ -290,6 +303,7 @@ def __init__(
290303
291304 unet_use_cross_frame_attention = None ,
292305 unet_use_temporal_attention = None ,
306+ use_inflated_groupnorm = None ,
293307
294308 use_motion_module = None ,
295309
@@ -318,6 +332,8 @@ def __init__(
318332 non_linearity = resnet_act_fn ,
319333 output_scale_factor = output_scale_factor ,
320334 pre_norm = resnet_pre_norm ,
335+
336+ use_inflated_groupnorm = use_inflated_groupnorm ,
321337 )
322338 )
323339 if dual_cross_attention :
@@ -421,6 +437,8 @@ def __init__(
421437 output_scale_factor = 1.0 ,
422438 add_downsample = True ,
423439 downsample_padding = 1 ,
440+
441+ use_inflated_groupnorm = None ,
424442
425443 use_motion_module = None ,
426444 motion_module_type = None ,
@@ -444,6 +462,8 @@ def __init__(
444462 non_linearity = resnet_act_fn ,
445463 output_scale_factor = output_scale_factor ,
446464 pre_norm = resnet_pre_norm ,
465+
466+ use_inflated_groupnorm = use_inflated_groupnorm ,
447467 )
448468 )
449469 motion_modules .append (
@@ -526,6 +546,7 @@ def __init__(
526546
527547 unet_use_cross_frame_attention = None ,
528548 unet_use_temporal_attention = None ,
549+ use_inflated_groupnorm = None ,
529550
530551 use_motion_module = None ,
531552
@@ -556,6 +577,8 @@ def __init__(
556577 non_linearity = resnet_act_fn ,
557578 output_scale_factor = output_scale_factor ,
558579 pre_norm = resnet_pre_norm ,
580+
581+ use_inflated_groupnorm = use_inflated_groupnorm ,
559582 )
560583 )
561584 if dual_cross_attention :
@@ -661,6 +684,8 @@ def __init__(
661684 output_scale_factor = 1.0 ,
662685 add_upsample = True ,
663686
687+ use_inflated_groupnorm = None ,
688+
664689 use_motion_module = None ,
665690 motion_module_type = None ,
666691 motion_module_kwargs = None ,
@@ -685,6 +710,8 @@ def __init__(
685710 non_linearity = resnet_act_fn ,
686711 output_scale_factor = output_scale_factor ,
687712 pre_norm = resnet_pre_norm ,
713+
714+ use_inflated_groupnorm = use_inflated_groupnorm ,
688715 )
689716 )
690717 motion_modules .append (
0 commit comments