Skip to content

Commit 21a8bee

Browse files
committed
fast mode
1 parent 2d6f5cd commit 21a8bee

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

configs/wildgs_slam.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ device: "cuda:0"
88
mapping:
99
online_plotting: False # render and save images online
1010
full_resolution: False # if using the full resolution for mapping, but we always keep downsampled size for tracking
11-
final_refine_iters: 20000 # iterations of final refinement
11+
final_refine_iters: 3000 # iterations of final refinement
1212
eval_before_final_ba: True
1313
deform_gaussians: True # apply transformation on Gaussians to account for loop closure and BA
1414
pcd_downsample: 32 # downsamples the unprojected depth map --> point cloud

src/frontend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def __update(self, force_to_add_keyframe):
6161
for itr in range(self.iters1):
6262
self.graph.update(None, None, use_inactive=True)
6363

64-
if itr == 1 and self.video.metric_depth_reg and self.cfg['tracking']["uncertainty_params"]['activate']:
65-
self.video.filter_high_err_mono_depth(self.t1-1,self.graph.ii,self.graph.jj)
64+
# if itr == 1 and self.video.metric_depth_reg and self.cfg['tracking']["uncertainty_params"]['activate']:
65+
# self.video.filter_high_err_mono_depth(self.t1-1,self.graph.ii,self.graph.jj)
6666

6767
d = self.video.distance([self.t1-2], [self.t1-1], beta=self.beta, bidirectional=True)
6868
# See self.max_consecutive_drop_of_keyframes in init for explanation of the following process

src/mapper.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,12 @@ def run(self):
241241
self.keyframe_optimizers = torch.optim.Adam(opt_params)
242242

243243
with Lock():
244-
gaussian_split = self.map_opt_online(
245-
self.current_window, iters=self.mapping_itr_num
246-
)
244+
if video_idx % 4 == 0:
245+
gaussian_split = self.map_opt_online(
246+
self.current_window, iters=self.mapping_itr_num
247+
)
248+
else:
249+
self._update_occ_aware_visibility(self.current_window)
247250
if gaussian_split:
248251
# do one more iteration after densify and prune
249252
self.map_opt_online(self.current_window, iters=1)
@@ -546,6 +549,20 @@ def _update_mapping_points(
546549
scales, "scaling"
547550
)["scaling"]
548551

552+
def _update_occ_aware_visibility(self, current_window):
553+
self.occ_aware_visibility = {}
554+
for kf_idx in current_window:
555+
viewpoint = self.cameras[kf_idx]
556+
render_pkg = render(
557+
viewpoint,
558+
self.gaussians,
559+
self.pipeline_params,
560+
self.background,
561+
)
562+
self.occ_aware_visibility[kf_idx] = (
563+
render_pkg["n_touched"] > 0
564+
).long()
565+
549566
def get_w2c_and_depth(self, video_idx, idx, mono_depth, print_info=False):
550567
est_frontend_depth, valid_depth_mask, c2w = self.video.get_depth_and_pose(
551568
video_idx, self.device
@@ -1146,18 +1163,7 @@ def map_opt_online(self, current_window, iters=1):
11461163
# Densifying / Pruning Gaussians
11471164
with torch.no_grad():
11481165
if cur_iter == iters - 1:
1149-
self.occ_aware_visibility = {}
1150-
for kf_idx in current_window:
1151-
viewpoint = self.cameras[kf_idx]
1152-
render_pkg = render(
1153-
viewpoint,
1154-
self.gaussians,
1155-
self.pipeline_params,
1156-
self.background,
1157-
)
1158-
self.occ_aware_visibility[kf_idx] = (
1159-
render_pkg["n_touched"] > 0
1160-
).long()
1166+
self._update_occ_aware_visibility(current_window)
11611167

11621168
self.gaussians.max_radii2D[visibility_filter] = torch.max(
11631169
self.gaussians.max_radii2D[visibility_filter],

src/utils/eval_traj.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,13 @@ def full_traj_eval(traj_filler, mapper, plot_parent_dir, plot_name, stream,logge
145145
traj_est_lietorch = traj_est_inv.inv()
146146
traj_est = traj_est_lietorch.matrix().data.cpu().numpy()
147147

148-
## refine non-keyframe-traj from the mapping
149-
for i in tqdm(range(traj_est.shape[0])):
150-
img_feat = dino_feats[i]
151-
w2c_refined = mapper.refine_pose_non_key_frame(i,
152-
torch.tensor(np.linalg.inv(traj_est[i])),
153-
features=img_feat)
154-
traj_est[i] = np.linalg.inv(w2c_refined.cpu().numpy())
148+
# ## refine non-keyframe-traj from the mapping
149+
# for i in tqdm(range(traj_est.shape[0])):
150+
# img_feat = dino_feats[i]
151+
# w2c_refined = mapper.refine_pose_non_key_frame(i,
152+
# torch.tensor(np.linalg.inv(traj_est[i])),
153+
# features=img_feat)
154+
# traj_est[i] = np.linalg.inv(w2c_refined.cpu().numpy())
155155

156156
kf_num = traj_filler.video.counter.value
157157
kf_timestamps = traj_filler.video.timestamp[:kf_num].cpu().int().numpy()

0 commit comments

Comments
 (0)