Skip to content

Commit 59ef5a5

Browse files
committed
add interp training
1 parent 26e665c commit 59ef5a5

File tree

4 files changed

+225
-7
lines changed

4 files changed

+225
-7
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ From CUHK and Tencent AI Lab.
277277

278278

279279
## 📝 Changelog
280-
- __[2024.05.24]__: 🔥🔥 Release WebVid10M-motion annotations.
280+
- __[2024.06.14]__: 🔥🔥 Release training code for interpolation.
281+
- __[2024.05.24]__: Release WebVid10M-motion annotations.
281282
- __[2024.05.05]__: Release training code.
282283
- __[2024.03.14]__: Release generative frame interpolation and looping video models (320x512).
283284
- __[2024.02.05]__: Release high-resolution models (320x512 & 576x1024).
@@ -361,6 +362,13 @@ We adopt `DDPShardedStrategy` by default for training, please make sure it is av
361362
```
362363
5. All the checkpoints, tensorboard records, and log info will be saved in `<YOUR_SAVE_ROOT_DIR>`.
363364

365+
### Generative Frame Interpolation
366+
Download the pretrained model DynamiCrafter512_interp and put `model.ckpt` at `checkpoints/dynamicrafter_512_interp_v1/model.ckpt`. Follow the same fine-tuning procedure as in "Image-to-Video Generation", and run the script below:
367+
```bash
368+
sh configs/training_512_v1.0/run_interp.sh
369+
```
370+
371+
364372
## 🎁 WebVid-10M-motion annotations (~2.6M)
365373
The annotations of our WebVid-10M-motion are available on [Huggingface Dataset](https://huggingface.co/datasets/Doubiiu/webvid10m_motion). In addition to the original annotations, we add three more motion-related annotations: `dynamic_confidence`, `dynamic_wording`, and `dynamic_source_category`. Please refer to our [supplementary document](https://arxiv.org/pdf/2310.12190) (Section D) for more details.
366374

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
model:
2+
pretrained_checkpoint: checkpoints/dynamicrafter_512_interp_v1/model.ckpt
3+
base_learning_rate: 1.0e-05
4+
scale_lr: False
5+
target: lvdm.models.ddpm3d.LatentVisualDiffusion
6+
params:
7+
rescale_betas_zero_snr: True
8+
parameterization: "v"
9+
linear_start: 0.00085
10+
linear_end: 0.012
11+
num_timesteps_cond: 1
12+
log_every_t: 200
13+
timesteps: 1000
14+
first_stage_key: video
15+
cond_stage_key: caption
16+
cond_stage_trainable: False
17+
image_proj_model_trainable: True
18+
conditioning_key: hybrid
19+
image_size: [40, 64]
20+
channels: 4
21+
scale_by_std: False
22+
scale_factor: 0.18215
23+
use_ema: False
24+
uncond_prob: 0.05
25+
uncond_type: 'empty_seq'
26+
rand_cond_frame: false
27+
use_dynamic_rescale: true
28+
base_scale: 0.7
29+
fps_condition_type: 'fps'
30+
perframe_ae: true
31+
interp_mode: true
32+
33+
unet_config:
34+
target: lvdm.modules.networks.openaimodel3d.UNetModel
35+
params:
36+
in_channels: 8
37+
out_channels: 4
38+
model_channels: 320
39+
attention_resolutions:
40+
- 4
41+
- 2
42+
- 1
43+
num_res_blocks: 2
44+
channel_mult:
45+
- 1
46+
- 2
47+
- 4
48+
- 4
49+
dropout: 0.1
50+
num_head_channels: 64
51+
transformer_depth: 1
52+
context_dim: 1024
53+
use_linear: true
54+
use_checkpoint: True
55+
temporal_conv: True
56+
temporal_attention: True
57+
temporal_selfatt_only: true
58+
use_relative_position: false
59+
use_causal_attention: False
60+
temporal_length: 16
61+
addition_attention: true
62+
image_cross_attention: true
63+
default_fs: 10
64+
fs_condition: true
65+
66+
first_stage_config:
67+
target: lvdm.models.autoencoder.AutoencoderKL
68+
params:
69+
embed_dim: 4
70+
monitor: val/rec_loss
71+
ddconfig:
72+
double_z: True
73+
z_channels: 4
74+
resolution: 256
75+
in_channels: 3
76+
out_ch: 3
77+
ch: 128
78+
ch_mult:
79+
- 1
80+
- 2
81+
- 4
82+
- 4
83+
num_res_blocks: 2
84+
attn_resolutions: []
85+
dropout: 0.0
86+
lossconfig:
87+
target: torch.nn.Identity
88+
89+
cond_stage_config:
90+
target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
91+
params:
92+
freeze: true
93+
layer: "penultimate"
94+
95+
img_cond_stage_config:
96+
target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
97+
params:
98+
freeze: true
99+
100+
image_proj_stage_config:
101+
target: lvdm.modules.encoders.resampler.Resampler
102+
params:
103+
dim: 1024
104+
depth: 4
105+
dim_head: 64
106+
heads: 12
107+
num_queries: 16
108+
embedding_dim: 1280
109+
output_dim: 1024
110+
ff_mult: 4
111+
video_length: 16
112+
113+
data:
114+
target: utils_data.DataModuleFromConfig
115+
params:
116+
batch_size: 2
117+
num_workers: 12
118+
wrap: false
119+
train:
120+
target: lvdm.data.webvid.WebVid
121+
params:
122+
data_dir: <WebVid10M DATA>
123+
meta_path: <.csv FILE>
124+
video_length: 16
125+
frame_stride: 6
126+
load_raw_resolution: true
127+
resolution: [320, 512]
128+
spatial_transform: resize_center_crop
129+
random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above)
130+
131+
lightning:
132+
precision: 16
133+
# strategy: deepspeed_stage_2
134+
trainer:
135+
benchmark: True
136+
accumulate_grad_batches: 2
137+
max_steps: 100000
138+
# logger
139+
log_every_n_steps: 50
140+
# val
141+
val_check_interval: 0.5
142+
gradient_clip_algorithm: 'norm'
143+
gradient_clip_val: 0.5
144+
callbacks:
145+
model_checkpoint:
146+
target: pytorch_lightning.callbacks.ModelCheckpoint
147+
params:
148+
every_n_train_steps: 9000 #1000
149+
filename: "{epoch}-{step}"
150+
save_weights_only: True
151+
metrics_over_trainsteps_checkpoint:
152+
target: pytorch_lightning.callbacks.ModelCheckpoint
153+
params:
154+
filename: '{epoch}-{step}'
155+
save_weights_only: True
156+
every_n_train_steps: 10000 #20000 # 3s/step*2w=
157+
batch_logger:
158+
target: callbacks.ImageLogger
159+
params:
160+
batch_frequency: 500
161+
to_local: False
162+
max_images: 8
163+
log_images_kwargs:
164+
ddim_steps: 50
165+
unconditional_guidance_scale: 7.5
166+
timestep_spacing: uniform_trailing
167+
guidance_rescale: 0.7
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# NCCL configuration
2+
# export NCCL_DEBUG=INFO
3+
# export NCCL_IB_DISABLE=0
4+
# export NCCL_IB_GID_INDEX=3
5+
# export NCCL_NET_GDR_LEVEL=3
6+
# export NCCL_TOPO_FILE=/tmp/topo.txt
7+
8+
# args
9+
name="training_512_v1.0"
10+
config_file=configs/${name}/config_interp.yaml
11+
12+
# save root dir for logs, checkpoints, tensorboard record, etc.
13+
save_root="<YOUR_SAVE_ROOT_DIR>"
14+
15+
mkdir -p $save_root/${name}_interp
16+
17+
## run
18+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
19+
--nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
20+
./main/trainer.py \
21+
--base $config_file \
22+
--train \
23+
--name ${name}_interp \
24+
--logdir $save_root \
25+
--devices $HOST_GPU_NUM \
26+
lightning.trainer.num_nodes=1
27+
28+
## debugging
29+
# CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 python3 -m torch.distributed.launch \
30+
# --nproc_per_node=6 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \
31+
# ./main/trainer.py \
32+
# --base $config_file \
33+
# --train \
34+
# --name ${name}_interp \
35+
# --logdir $save_root \
36+
# --devices 6 \
37+
# lightning.trainer.num_nodes=1

lvdm/models/ddpm3d.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def __init__(self,
481481
use_dynamic_rescale=False,
482482
base_scale=0.7,
483483
turning_step=400,
484-
loop_video=False,
484+
interp_mode=False,
485485
fps_condition_type='fs',
486486
perframe_ae=False,
487487
# added
@@ -502,7 +502,7 @@ def __init__(self,
502502
self.cond_stage_key = cond_stage_key
503503
self.noise_strength = noise_strength
504504
self.use_dynamic_rescale = use_dynamic_rescale
505-
self.loop_video = loop_video
505+
self.interp_mode = interp_mode
506506
self.fps_condition_type = fps_condition_type
507507
self.perframe_ae = perframe_ae
508508

@@ -1093,10 +1093,16 @@ def get_batch_input(self, batch, random_uncond, return_first_stage_outputs=False
10931093
img_emb = self.image_proj_model(img_emb)
10941094

10951095
if self.model.conditioning_key == 'hybrid':
1096-
## simply repeat the cond_frame to match the seq_len of z
1097-
img_cat_cond = z[:,:,cond_frame_index,:,:]
1098-
img_cat_cond = img_cat_cond.unsqueeze(2)
1099-
img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
1096+
if self.interp_mode:
1097+
## starting frame + (L-2 empty frames) + ending frame
1098+
img_cat_cond = torch.zeros_like(z)
1099+
img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
1100+
img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
1101+
else:
1102+
## simply repeat the cond_frame to match the seq_len of z
1103+
img_cat_cond = z[:,:,cond_frame_index,:,:]
1104+
img_cat_cond = img_cat_cond.unsqueeze(2)
1105+
img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
11001106

11011107
cond["c_concat"] = [img_cat_cond] # b c t h w
11021108
cond["c_crossattn"] = [torch.cat([prompt_imb, img_emb], dim=1)] ## concat in the seq_len dim

0 commit comments

Comments
 (0)