Skip to content

Commit 8016d2b

Browse files
committed
option to save reconstruction w/ full res depths
1 parent 92027b3 commit 8016d2b

File tree

6 files changed

+43
-6
lines changed

6 files changed

+43
-6
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ python setup.py install
5858
./tools/download_sample_data.sh
5959
```
6060

61-
Run the demo on any of the samples (all demos can be run on a GPU with 11G of memory). While running, press the "s" key to increase the filtering threshold (= more points) and "a" to decrease the filtering threshold (= fewer points).
61+
Run the demo on any of the samples (all demos can be run on a GPU with 11G of memory). While running, press the "s" key to increase the filtering threshold (= more points) and "a" to decrease the filtering threshold (= fewer points). To save the reconstruction with full resolution depth maps use the `--reconstruction_path` flag.
6262

6363

6464
```Python

demo.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@ def image_stream(imagedir, calib, stride):
5656
yield t, image[None], intrinsics
5757

5858

59+
def save_reconstruction(droid, reconstruction_path):
60+
61+
from pathlib import Path
62+
import random
63+
import string
64+
65+
t = droid.video.counter.value
66+
tstamps = droid.video.tstamp[:t].cpu().numpy()
67+
images = droid.video.images[:t].cpu().numpy()
68+
disps = droid.video.disps_up[:t].cpu().numpy()
69+
poses = droid.video.poses[:t].cpu().numpy()
70+
intrinsics = droid.video.intrinsics[:t].cpu().numpy()
71+
72+
Path("reconstructions/{}".format(reconstruction_path)).mkdir(parents=True, exist_ok=True)
73+
np.save("reconstructions/{}/tstamps.npy".format(reconstruction_path), tstamps)
74+
np.save("reconstructions/{}/images.npy".format(reconstruction_path), images)
75+
np.save("reconstructions/{}/disps.npy".format(reconstruction_path), disps)
76+
np.save("reconstructions/{}/poses.npy".format(reconstruction_path), poses)
77+
np.save("reconstructions/{}/intrinsics.npy".format(reconstruction_path), intrinsics)
78+
79+
5980
if __name__ == '__main__':
6081
parser = argparse.ArgumentParser()
6182
parser.add_argument("--imagedir", type=str, help="path to image directory")
@@ -80,13 +101,19 @@ def image_stream(imagedir, calib, stride):
80101
parser.add_argument("--backend_thresh", type=float, default=22.0)
81102
parser.add_argument("--backend_radius", type=int, default=2)
82103
parser.add_argument("--backend_nms", type=int, default=3)
104+
parser.add_argument("--upsample", action="store_true")
105+
parser.add_argument("--reconstruction_path", help="path to saved reconstruction")
83106
args = parser.parse_args()
84107

85108
args.stereo = False
86109
torch.multiprocessing.set_start_method('spawn')
87110

88111
droid = None
89112

113+
# need high resolution depths
114+
if args.reconstruction_path is not None:
115+
args.upsample = True
116+
90117
tstamps = []
91118
for (t, image, intrinsics) in tqdm(image_stream(args.imagedir, args.calib, args.stride)):
92119
if t < args.t0:
@@ -101,4 +128,7 @@ def image_stream(imagedir, calib, stride):
101128

102129
droid.track(t, image, intrinsics=intrinsics)
103130

131+
if args.reconstruction_path is not None:
132+
save_reconstruction(droid, args.reconstruction_path)
133+
104134
traj_est = droid.terminate(image_stream(args.imagedir, args.calib, args.stride))

droid_slam/droid_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self, net, video, args):
1515
self.t0 = 0
1616
self.t1 = 0
1717

18+
self.upsample = args.upsample
1819
self.beta = args.beta
1920
self.backend_thresh = args.backend_thresh
2021
self.backend_radius = args.backend_radius
@@ -28,7 +29,7 @@ def __call__(self, steps=12):
2829
if not self.video.stereo and not torch.any(self.video.disps_sens):
2930
self.video.normalize()
3031

31-
graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t)
32+
graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t, upsample=self.upsample)
3233

3334
graph.add_proximity_factors(rad=self.backend_radius,
3435
nms=self.backend_nms,

droid_slam/droid_frontend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class DroidFrontend:
1010
def __init__(self, net, video, args):
1111
self.video = video
1212
self.update_op = net.update
13-
self.graph = FactorGraph(video, net.update, max_factors=48)
13+
self.graph = FactorGraph(video, net.update, max_factors=48, upsample=args.upsample)
1414

1515
# local optimization window
1616
self.t0 = 0

droid_slam/factor_graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99

1010

1111
class FactorGraph:
12-
def __init__(self, video, update_op, device="cuda:0", corr_impl="volume", max_factors=-1):
12+
def __init__(self, video, update_op, device="cuda:0", corr_impl="volume", max_factors=-1, upsample=False):
1313
self.video = video
1414
self.update_op = update_op
1515
self.device = device
1616
self.max_factors = max_factors
1717
self.corr_impl = corr_impl
18+
self.upsample = upsample
1819

1920
# operator at 1/8 resolution
2021
self.ht = ht = video.ht // 8
@@ -239,6 +240,9 @@ def update(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, motion_o
239240
self.video.ba(target, weight, damping, ii, jj, t0, t1,
240241
itrs=itrs, lm=1e-4, ep=0.1, motion_only=motion_only)
241242

243+
if self.upsample:
244+
self.video.upsample(torch.unique(self.ii), upmask)
245+
242246
self.age += 1
243247

244248

@@ -270,9 +274,11 @@ def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, s
270274

271275
with torch.cuda.amp.autocast(enabled=True):
272276

273-
net, delta, weight, damping, _ = \
277+
net, delta, weight, damping, upmask = \
274278
self.update_op(self.net[:,v], self.video.inps[None,iis], corr1, motn[:,v], iis, jjs)
275279

280+
if self.upsample:
281+
self.video.upsample(torch.unique(iis), upmask)
276282

277283
self.net[:,v] = net
278284
self.target[:,v] = coords1[:,v] + delta.float()

environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: droidenv5
1+
name: droidenv
22
channels:
33
- rusty1s
44
- pytorch

0 commit comments

Comments
 (0)