Skip to content

Commit 2407bfa

Browse files
committed
Add data preprocessing scripts
1 parent 166bf74 commit 2407bfa

File tree

16 files changed

+1719
-4
lines changed

16 files changed

+1719
-4
lines changed

README.md

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,23 @@ We have provided some videos [here](https://drive.google.com/file/d/1cKZF6ILeokC
4343

4444
### Customize your own data
4545

46-
*Stay tuned for data preparation scripts.*
46+
We segement video sequences using [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). Once you obtain the mask files, place them in the folder `all_sequences/{YOUR_SEQUENCE_NAME}/{YOUR_SEQUENCE_NAME}_masks`. Next, execute the following command:
4747

48-
Please organize your own data as follows:
48+
```shell
49+
cd data_preprocessing
50+
python preproc_mask.py
51+
```
52+
53+
We extract optical flows of video sequences using [RAFT](https://github.com/princeton-vl/RAFT). To get started, please follow the instructions provided [here](https://github.com/princeton-vl/RAFT#demos) to download their pretrained model. Once downloaded, place the model in the `data_preprocessing/RAFT/models` folder. After that, you can execute the following command:
54+
55+
```shell
56+
cd data_preprocessing/RAFT
57+
./run_raft.sh
58+
```
59+
60+
Remember to update the sequence name and root directory in both `data_preprocessing/preproc_mask.py` and `data_preprocessing/RAFT/run_raft.sh` accordingly.
61+
62+
After obtaining the files, please organize your own data as follows:
4963

5064
```
5165
CoDeF
@@ -131,7 +145,7 @@ After running the script, the reconstructed videos can be found in `results/all_
131145
## Test video translation
132146

133147
After obtaining the canonical image through [this step](#anchor), use your preferred text prompts to transfer it using [ControlNet](https://github.com/lllyasviel/ControlNet).
134-
Once you have the transferred canonical image, place it in `all_sequences/${NAME}/${EXP_NAME}_control` (i.e. `CANONICAL_DIR` in `scripts/test_canonical.sh`).
148+
Once you have the transferred canonical image, place it in `all_sequences/${NAME}/${EXP_NAME}_control` (i.e. `CANONICAL_DIR` in `scripts/test_canonical.sh`).
135149

136150
Then run
137151

@@ -147,7 +161,7 @@ The transferred results can be seen in `results/all_sequences/{NAME}/{EXP_NAME}_
147161

148162
```bibtex
149163
@article{ouyang2023codef,
150-
title={CoDeF: Content Deformation Fields for Temporally Consistent Video Processing},
164+
title={CoDeF: Content Deformation Fields for Temporally Consistent Video Processing},
151165
author={Hao Ouyang and Qiuyu Wang and Yuxi Xiao and Qingyan Bai and Juntao Zhang and Kecheng Zheng and Xiaowei Zhou and Qifeng Chen and Yujun Shen},
152166
journal={arXiv preprint arXiv:2308.07926},
153167
year={2023}

data_preprocessing/RAFT/core/__init__.py

Whitespace-only changes.
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from utils.utils import bilinear_sampler, coords_grid
4+
5+
try:
6+
import alt_cuda_corr
7+
except:
8+
# alt_cuda_corr is not compiled
9+
pass
10+
11+
12+
class CorrBlock:
13+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14+
self.num_levels = num_levels
15+
self.radius = radius
16+
self.corr_pyramid = []
17+
18+
# all pairs correlation
19+
corr = CorrBlock.corr(fmap1, fmap2)
20+
21+
batch, h1, w1, dim, h2, w2 = corr.shape
22+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23+
24+
self.corr_pyramid.append(corr)
25+
for i in range(self.num_levels-1):
26+
corr = F.avg_pool2d(corr, 2, stride=2)
27+
self.corr_pyramid.append(corr)
28+
29+
def __call__(self, coords):
30+
r = self.radius
31+
coords = coords.permute(0, 2, 3, 1)
32+
batch, h1, w1, _ = coords.shape
33+
34+
out_pyramid = []
35+
for i in range(self.num_levels):
36+
corr = self.corr_pyramid[i]
37+
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38+
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40+
41+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43+
coords_lvl = centroid_lvl + delta_lvl
44+
45+
corr = bilinear_sampler(corr, coords_lvl)
46+
corr = corr.view(batch, h1, w1, -1)
47+
out_pyramid.append(corr)
48+
49+
out = torch.cat(out_pyramid, dim=-1)
50+
return out.permute(0, 3, 1, 2).contiguous().float()
51+
52+
@staticmethod
53+
def corr(fmap1, fmap2):
54+
batch, dim, ht, wd = fmap1.shape
55+
fmap1 = fmap1.view(batch, dim, ht*wd)
56+
fmap2 = fmap2.view(batch, dim, ht*wd)
57+
58+
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59+
corr = corr.view(batch, ht, wd, 1, ht, wd)
60+
return corr / torch.sqrt(torch.tensor(dim).float())
61+
62+
63+
class AlternateCorrBlock:
64+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65+
self.num_levels = num_levels
66+
self.radius = radius
67+
68+
self.pyramid = [(fmap1, fmap2)]
69+
for i in range(self.num_levels):
70+
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71+
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72+
self.pyramid.append((fmap1, fmap2))
73+
74+
def __call__(self, coords):
75+
coords = coords.permute(0, 2, 3, 1)
76+
B, H, W, _ = coords.shape
77+
dim = self.pyramid[0][0].shape[1]
78+
79+
corr_list = []
80+
for i in range(self.num_levels):
81+
r = self.radius
82+
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83+
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84+
85+
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86+
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87+
corr_list.append(corr.squeeze(1))
88+
89+
corr = torch.stack(corr_list, dim=1)
90+
corr = corr.reshape(B, -1, H, W)
91+
return corr / torch.sqrt(torch.tensor(dim).float())
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2+
3+
import numpy as np
4+
import torch
5+
import torch.utils.data as data
6+
import torch.nn.functional as F
7+
8+
import os
9+
import math
10+
import random
11+
from glob import glob
12+
import os.path as osp
13+
14+
from utils import frame_utils
15+
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16+
17+
18+
class FlowDataset(data.Dataset):
19+
def __init__(self, aug_params=None, sparse=False):
20+
self.augmentor = None
21+
self.sparse = sparse
22+
if aug_params is not None:
23+
if sparse:
24+
self.augmentor = SparseFlowAugmentor(**aug_params)
25+
else:
26+
self.augmentor = FlowAugmentor(**aug_params)
27+
28+
self.is_test = False
29+
self.init_seed = False
30+
self.flow_list = []
31+
self.image_list = []
32+
self.extra_info = []
33+
34+
def __getitem__(self, index):
35+
36+
if self.is_test:
37+
img1 = frame_utils.read_gen(self.image_list[index][0])
38+
img2 = frame_utils.read_gen(self.image_list[index][1])
39+
img1 = np.array(img1).astype(np.uint8)[..., :3]
40+
img2 = np.array(img2).astype(np.uint8)[..., :3]
41+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43+
return img1, img2, self.extra_info[index]
44+
45+
if not self.init_seed:
46+
worker_info = torch.utils.data.get_worker_info()
47+
if worker_info is not None:
48+
torch.manual_seed(worker_info.id)
49+
np.random.seed(worker_info.id)
50+
random.seed(worker_info.id)
51+
self.init_seed = True
52+
53+
index = index % len(self.image_list)
54+
valid = None
55+
if self.sparse:
56+
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57+
else:
58+
flow = frame_utils.read_gen(self.flow_list[index])
59+
60+
img1 = frame_utils.read_gen(self.image_list[index][0])
61+
img2 = frame_utils.read_gen(self.image_list[index][1])
62+
63+
flow = np.array(flow).astype(np.float32)
64+
img1 = np.array(img1).astype(np.uint8)
65+
img2 = np.array(img2).astype(np.uint8)
66+
67+
# grayscale images
68+
if len(img1.shape) == 2:
69+
img1 = np.tile(img1[...,None], (1, 1, 3))
70+
img2 = np.tile(img2[...,None], (1, 1, 3))
71+
else:
72+
img1 = img1[..., :3]
73+
img2 = img2[..., :3]
74+
75+
if self.augmentor is not None:
76+
if self.sparse:
77+
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78+
else:
79+
img1, img2, flow = self.augmentor(img1, img2, flow)
80+
81+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83+
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84+
85+
if valid is not None:
86+
valid = torch.from_numpy(valid)
87+
else:
88+
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89+
90+
return img1, img2, flow, valid.float()
91+
92+
93+
def __rmul__(self, v):
94+
self.flow_list = v * self.flow_list
95+
self.image_list = v * self.image_list
96+
return self
97+
98+
def __len__(self):
99+
return len(self.image_list)
100+
101+
102+
class MpiSintel(FlowDataset):
103+
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104+
super(MpiSintel, self).__init__(aug_params)
105+
flow_root = osp.join(root, split, 'flow')
106+
image_root = osp.join(root, split, dstype)
107+
108+
if split == 'test':
109+
self.is_test = True
110+
111+
for scene in os.listdir(image_root):
112+
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113+
for i in range(len(image_list)-1):
114+
self.image_list += [ [image_list[i], image_list[i+1]] ]
115+
self.extra_info += [ (scene, i) ] # scene and frame_id
116+
117+
if split != 'test':
118+
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119+
120+
121+
class FlyingChairs(FlowDataset):
122+
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123+
super(FlyingChairs, self).__init__(aug_params)
124+
125+
images = sorted(glob(osp.join(root, '*.ppm')))
126+
flows = sorted(glob(osp.join(root, '*.flo')))
127+
assert (len(images)//2 == len(flows))
128+
129+
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130+
for i in range(len(flows)):
131+
xid = split_list[i]
132+
if (split=='training' and xid==1) or (split=='validation' and xid==2):
133+
self.flow_list += [ flows[i] ]
134+
self.image_list += [ [images[2*i], images[2*i+1]] ]
135+
136+
137+
class FlyingThings3D(FlowDataset):
138+
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139+
super(FlyingThings3D, self).__init__(aug_params)
140+
141+
for cam in ['left']:
142+
for direction in ['into_future', 'into_past']:
143+
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144+
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145+
146+
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147+
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148+
149+
for idir, fdir in zip(image_dirs, flow_dirs):
150+
images = sorted(glob(osp.join(idir, '*.png')) )
151+
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152+
for i in range(len(flows)-1):
153+
if direction == 'into_future':
154+
self.image_list += [ [images[i], images[i+1]] ]
155+
self.flow_list += [ flows[i] ]
156+
elif direction == 'into_past':
157+
self.image_list += [ [images[i+1], images[i]] ]
158+
self.flow_list += [ flows[i+1] ]
159+
160+
161+
class KITTI(FlowDataset):
162+
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163+
super(KITTI, self).__init__(aug_params, sparse=True)
164+
if split == 'testing':
165+
self.is_test = True
166+
167+
root = osp.join(root, split)
168+
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169+
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170+
171+
for img1, img2 in zip(images1, images2):
172+
frame_id = img1.split('/')[-1]
173+
self.extra_info += [ [frame_id] ]
174+
self.image_list += [ [img1, img2] ]
175+
176+
if split == 'training':
177+
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178+
179+
180+
class HD1K(FlowDataset):
181+
def __init__(self, aug_params=None, root='datasets/HD1k'):
182+
super(HD1K, self).__init__(aug_params, sparse=True)
183+
184+
seq_ix = 0
185+
while 1:
186+
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187+
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188+
189+
if len(flows) == 0:
190+
break
191+
192+
for i in range(len(flows)-1):
193+
self.flow_list += [flows[i]]
194+
self.image_list += [ [images[i], images[i+1]] ]
195+
196+
seq_ix += 1
197+
198+
199+
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200+
""" Create the data loader for the corresponding trainign set """
201+
202+
if args.stage == 'chairs':
203+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204+
train_dataset = FlyingChairs(aug_params, split='training')
205+
206+
elif args.stage == 'things':
207+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208+
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209+
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210+
train_dataset = clean_dataset + final_dataset
211+
212+
elif args.stage == 'sintel':
213+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214+
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215+
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216+
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217+
218+
if TRAIN_DS == 'C+T+K+S+H':
219+
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220+
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221+
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222+
223+
elif TRAIN_DS == 'C+T+K/S':
224+
train_dataset = 100*sintel_clean + 100*sintel_final + things
225+
226+
elif args.stage == 'kitti':
227+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228+
train_dataset = KITTI(aug_params, split='training')
229+
230+
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231+
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232+
233+
print('Training with %d image pairs' % len(train_dataset))
234+
return train_loader
235+

0 commit comments

Comments
 (0)