Skip to content

Commit 7234a93

Browse files
committed
stereo + rgbd eval code
1 parent 8207c9f commit 7234a93

19 files changed

+499
-120
lines changed

README.md

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Zachary Teed and Jia Deng
1515
}
1616
```
1717

18-
**Initial Code Release:** This repo currently provides a single GPU implementation of our monocular SLAM system. It also contains demos, training, and evaluation scripts. Stereo, RGB-D, and multi-GPU code will be added on **September 7**.
18+
**Initial Code Release:** This repo currently provides a single GPU implementation of our monocular SLAM system. It currently contains demos, training, and evaluation scripts.
1919

2020

2121
## Requirements
@@ -83,25 +83,33 @@ fx fy cx cy [k1 k2 p1 p2 [ k3 [ k4 k5 k6 ]]]
8383
```
8484
with parameters in brackets optional.
8585

86-
## Evaluation (Monocular)
87-
We provide evaluation scripts for TartanAir, EuRoC, and TUM. EuRoC and TUM can be run on a 1080Ti. The TartanAir validation script will require 24G of memory.
86+
## Evaluation
87+
We provide evaluation scripts for TartanAir, EuRoC, and TUM. EuRoC and TUM can be run on a 1080Ti. The TartanAir and ETH3D evaluations will require 24G of memory.
8888

89-
### EuRoC
89+
### TartanAir (Mono + Stereo)
90+
Download the [TartanAir](https://theairlab.org/tartanair-dataset/) dataset using the script `thirdparty/tartanair_tools/download_training.py` and put them in `datasets/TartanAir`
91+
```Bash
92+
./tools/validate_tartanair.sh --plot_curve # monocular eval
93+
./tools/validate_tartanair.sh --plot_curve --stereo # stereo eval
94+
```
95+
96+
### EuRoC (Mono + Stereo)
9097
Download the [EuRoC](https://projects.asl.ethz.ch/datasets/doku.php?id=kmavvisualinertialdatasets) sequences (ASL format) and put them in `datasets/EuRoC`
9198
```Bash
92-
./tools/evaluate_euroc.sh
99+
./tools/evaluate_euroc.sh # monocular eval
100+
./tools/evaluate_euroc.sh --stereo # stereo eval
93101
```
94102

95-
### TUM-RGBD
103+
### TUM-RGBD (Mono)
96104
Download the fr1 sequences from [TUM-RGBD](https://vision.in.tum.de/data/datasets/rgbd-dataset/download) and put them in `datasets/TUM-RGBD`
97105
```Bash
98-
./tools/evaluate_tum.sh
106+
./tools/evaluate_tum.sh # monocular eval
99107
```
100108

101-
### TartanAir
102-
Download the [TartanAir](https://theairlab.org/tartanair-dataset/) dataset using the script `thirdparty/tartanair_tools/download_training.py` and put them in `datasets/TartanAir`
109+
### ETH3D (RGB-D)
110+
Download the [ETH3D](https://www.eth3d.net/slam_datasets) dataset
103111
```Bash
104-
./tools/validate_tartanair.sh
112+
./tools/evaluate_eth3d.sh # RGB-D eval
105113
```
106114

107115
## Training

droid_slam/depth_video.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import geom.projective_ops as pops
1111

1212
class DepthVideo:
13-
def __init__(self, image_size=[480, 640], buffer=1024, device="cuda:0"):
13+
def __init__(self, image_size=[480, 640], buffer=1024, stereo=False, device="cuda:0"):
1414

1515
# current keyframe count
1616
self.counter = Value('i', 0)
@@ -25,11 +25,15 @@ def __init__(self, image_size=[480, 640], buffer=1024, device="cuda:0"):
2525
self.red = torch.zeros(buffer, device="cuda", dtype=torch.bool).share_memory_()
2626
self.poses = torch.zeros(buffer, 7, device="cuda", dtype=torch.float).share_memory_()
2727
self.disps = torch.ones(buffer, ht//8, wd//8, device="cuda", dtype=torch.float).share_memory_()
28+
self.disps_sens = torch.zeros(buffer, ht//8, wd//8, device="cuda", dtype=torch.float).share_memory_()
2829
self.disps_up = torch.zeros(buffer, ht, wd, device="cuda", dtype=torch.float).share_memory_()
2930
self.intrinsics = torch.zeros(buffer, 4, device="cuda", dtype=torch.float).share_memory_()
3031

32+
self.stereo = stereo
33+
c = 1 if not self.stereo else 2
34+
3135
### feature attributes ###
32-
self.fmaps = torch.zeros(buffer, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
36+
self.fmaps = torch.zeros(buffer, c, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
3337
self.nets = torch.zeros(buffer, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
3438
self.inps = torch.zeros(buffer, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
3539

@@ -57,16 +61,20 @@ def __item_setter(self, index, item):
5761
self.disps[index] = item[3]
5862

5963
if item[4] is not None:
60-
self.intrinsics[index] = item[4]
64+
depth = item[4][3::8,3::8]
65+
self.disps_sens[index] = torch.where(depth>0, 1.0/depth, depth)
6166

62-
if len(item) > 5:
63-
self.fmaps[index] = item[5]
67+
if item[5] is not None:
68+
self.intrinsics[index] = item[5]
6469

6570
if len(item) > 6:
66-
self.nets[index] = item[6]
71+
self.fmaps[index] = item[6]
6772

6873
if len(item) > 7:
69-
self.inps[index] = item[7]
74+
self.nets[index] = item[7]
75+
76+
if len(item) > 8:
77+
self.inps[index] = item[8]
7078

7179
def __setitem__(self, index, item):
7280
with self.get_lock():
@@ -179,11 +187,7 @@ def ba(self, target, weight, eta, ii, jj, t0=1, t1=None, itrs=2, lm=1e-4, ep=0.1
179187
if t1 is None:
180188
t1 = max(ii.max().item(), jj.max().item()) + 1
181189

182-
if eta is None:
183-
k = torch.unique(torch.cat([ii, jj], 0)).shape[0]
184-
eta = 1e-7 * torch.ones([k, self.ht//8, self.wd//8], device="cuda")
185-
186-
droid_backends.ba(self.poses, self.disps, self.intrinsics[0],
190+
droid_backends.ba(self.poses, self.disps, self.intrinsics[0], self.disps_sens,
187191
target, weight, eta, ii, jj, t0, t1, itrs, lm, ep, motion_only)
188192

189193
self.disps.clamp_(min=0.001)

droid_slam/droid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, args):
2121
self.disable_vis = args.disable_vis
2222

2323
# store images, depth, poses, intrinsics (shared between processes)
24-
self.video = DepthVideo(args.image_size, args.buffer)
24+
self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)
2525

2626
# filter incoming frames so that there is enough motion
2727
self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)

droid_slam/droid_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ def __call__(self, steps=12):
2525
""" main update """
2626

2727
t = self.video.counter.value
28-
self.video.normalize()
28+
if not self.video.stereo and not torch.any(self.video.disps_sens):
29+
self.video.normalize()
2930

30-
graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=100000)
31+
graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t)
3132

3233
graph.add_proximity_factors(rad=self.backend_radius,
3334
nms=self.backend_nms,
@@ -37,4 +38,3 @@ def __call__(self, steps=12):
3738
graph.update_lowmem(steps=steps)
3839
graph.clear_edges()
3940
self.video.dirty[:t] = True
40-

droid_slam/droid_frontend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def __update(self):
4444
self.graph.add_proximity_factors(self.t1-5, max(self.t1-self.frontend_window, 0),
4545
rad=self.frontend_radius, nms=self.frontend_nms, thresh=self.frontend_thresh, beta=self.beta, remove=True)
4646

47+
self.video.disps[self.t1-1] = torch.where(self.video.disps_sens[self.t1-1] > 0,
48+
self.video.disps_sens[self.t1-1], self.video.disps[self.t1-1])
49+
4750
for itr in range(self.iters1):
4851
self.graph.update(None, None, use_inactive=True)
4952

@@ -80,11 +83,12 @@ def __initialize(self):
8083
for itr in range(8):
8184
self.graph.update(1, use_inactive=True)
8285

83-
self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh)
86+
self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False)
8487

85-
for itr in range(12):
88+
for itr in range(8):
8689
self.graph.update(1, use_inactive=True)
8790

91+
8892
# self.video.normalize()
8993
self.video.poses[self.t1] = self.video.poses[self.t1-1].clone()
9094
self.video.disps[self.t1] = self.video.disps[self.t1-4:self.t1].mean()
@@ -99,6 +103,8 @@ def __initialize(self):
99103
self.video.ready.value = 1
100104
self.video.dirty[:self.t1] = True
101105

106+
self.graph.rm_factors(self.graph.ii < self.warmup-4, store=True)
107+
102108
def __call__(self):
103109
""" main update """
104110

droid_slam/factor_graph.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ def add_factors(self, ii, jj, remove=False):
109109

110110
# correlation volume for new edges
111111
if self.corr_impl == "volume":
112-
fmap1 = self.video.fmaps[ii].to(self.device).unsqueeze(0)
113-
fmap2 = self.video.fmaps[jj].to(self.device).unsqueeze(0)
112+
c = (ii == jj).long()
113+
fmap1 = self.video.fmaps[ii,0].to(self.device).unsqueeze(0)
114+
fmap2 = self.video.fmaps[jj,c].to(self.device).unsqueeze(0)
114115
corr = CorrBlock(fmap1, fmap2)
115116
self.corr = corr if self.corr is None else self.corr.cat(corr)
116117

@@ -167,20 +168,27 @@ def rm_keyframe(self, ix):
167168
with self.video.get_lock():
168169
self.video.poses[ix] = self.video.poses[ix+1]
169170
self.video.disps[ix] = self.video.disps[ix+1]
171+
self.video.disps_sens[ix] = self.video.disps_sens[ix+1]
170172
self.video.intrinsics[ix] = self.video.intrinsics[ix+1]
171173

172174
self.video.nets[ix] = self.video.nets[ix+1]
173175
self.video.inps[ix] = self.video.inps[ix+1]
174176
self.video.fmaps[ix] = self.video.fmaps[ix+1]
175177

178+
m = (self.ii_inac == ix) | (self.jj_inac == ix)
179+
self.ii_inac[self.ii_inac >= ix] -= 1
180+
self.jj_inac[self.jj_inac >= ix] -= 1
181+
182+
if torch.any(m):
183+
self.ii_inac = self.ii_inac[~m]
184+
self.jj_inac = self.jj_inac[~m]
185+
self.target_inac = self.target_inac[:,~m]
186+
self.weight_inac = self.weight_inac[:,~m]
187+
176188
m = (self.ii == ix) | (self.jj == ix)
177189

178190
self.ii[self.ii >= ix] -= 1
179191
self.jj[self.jj >= ix] -= 1
180-
181-
self.ii_inac[self.ii_inac >= ix] -= 1
182-
self.jj_inac[self.jj_inac >= ix] -= 1
183-
184192
self.rm_factors(m, store=False)
185193

186194

@@ -239,7 +247,9 @@ def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, s
239247

240248
# alternate corr implementation
241249
t = self.video.counter.value
242-
corr_op = AltCorrBlock(self.video.fmaps[None,:t])
250+
251+
num, rig, ch, ht, wd = self.video.fmaps.shape
252+
corr_op = AltCorrBlock(self.video.fmaps.view(1, num*rig, ch, ht, wd))
243253

244254
for step in range(steps):
245255
print("Global BA Iteration #{}".format(step+1))
@@ -253,11 +263,12 @@ def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, s
253263
v = (self.ii >= i) & (self.ii < i + s)
254264
iis = self.ii[v]
255265
jjs = self.jj[v]
256-
266+
257267
ht, wd = self.coords0.shape[0:2]
258-
corr1 = corr_op(coords1[:,v], iis, jjs)
268+
corr1 = corr_op(coords1[:,v], rig * iis, rig * jjs + (iis == jjs).long())
259269

260-
with torch.cuda.amp.autocast(enabled=True):
270+
with torch.cuda.amp.autocast(enabled=True):
271+
261272
net, delta, weight, damping, _ = \
262273
self.update_op(self.net[:,v], self.video.inps[None,iis], corr1, motn[:,v], iis, jjs)
263274

@@ -267,7 +278,7 @@ def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, s
267278
self.weight[:,v] = weight.float()
268279
self.damping[torch.unique(iis)] = damping
269280

270-
damping = self.damping[torch.unique(self.ii)].contiguous() + EP
281+
damping = .2 * self.damping[torch.unique(self.ii)].contiguous() + EP
271282
target = self.target.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous()
272283
weight = self.weight.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous()
273284

@@ -277,15 +288,16 @@ def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, s
277288

278289
self.video.dirty[:t] = True
279290

280-
281291
def add_neighborhood_factors(self, t0, t1, r=3):
282292
""" add edges between neighboring frames within radius r """
283293

284294
ii, jj = torch.meshgrid(torch.arange(t0,t1), torch.arange(t0,t1))
285295
ii = ii.reshape(-1).to(dtype=torch.long, device=self.device)
286296
jj = jj.reshape(-1).to(dtype=torch.long, device=self.device)
287297

288-
keep = ((ii - jj).abs() > 0) & ((ii - jj).abs() <= r)
298+
c = 1 if self.video.stereo else 0
299+
300+
keep = ((ii - jj).abs() > c) & ((ii - jj).abs() <= r)
289301
self.add_factors(ii[keep], jj[keep])
290302

291303

@@ -307,8 +319,6 @@ def add_proximity_factors(self, t0=0, t1=0, rad=2, nms=2, beta=0.25, thresh=16.0
307319
ii1 = torch.cat([self.ii, self.ii_bad, self.ii_inac], 0)
308320
jj1 = torch.cat([self.jj, self.jj_bad, self.jj_inac], 0)
309321
for i, j in zip(ii1.cpu().numpy(), jj1.cpu().numpy()):
310-
if abs(i - j) <= 2:
311-
continue
312322
for di in range(-nms, nms+1):
313323
for dj in range(-nms, nms+1):
314324
if abs(di) + abs(dj) <= max(min(abs(i-j)-2, nms), 0):
@@ -318,17 +328,26 @@ def add_proximity_factors(self, t0=0, t1=0, rad=2, nms=2, beta=0.25, thresh=16.0
318328
if (t0 <= i1 < t) and (t1 <= j1 < t):
319329
d[(i1-t0)*(t-t1) + (j1-t1)] = np.inf
320330

331+
321332
es = []
322333
for i in range(t0, t):
323-
for j in range(i+1, min(i+rad+1, t)):
334+
if self.video.stereo:
335+
es.append((i, i))
336+
d[(i-t0)*(t-t1) + (i-t1)] = np.inf
337+
338+
for j in range(max(i-rad-1,0), i):
324339
es.append((i,j))
325340
es.append((j,i))
341+
d[(i-t0)*(t-t1) + (j-t1)] = np.inf
326342

327343
ix = torch.argsort(d)
328344
for k in ix:
329345
if d[k].item() > thresh:
330346
continue
331347

348+
if len(es) > self.max_factors:
349+
break
350+
332351
i = ii[k]
333352
j = jj[k]
334353

droid_slam/geom/projective_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def projective_transform(poses, depths, intrinsics, ii, jj, jacobian=False, retu
101101

102102
# transform
103103
Gij = poses[:,jj] * poses[:,ii].inv()
104+
105+
Gij.data[:,ii==jj] = torch.as_tensor([-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda")
104106
X1, Ja = actp(Gij, X0, jacobian=jacobian)
105107

106108
# project (pinhole)

0 commit comments

Comments (0)