Skip to content

Commit f45ba24

Browse files
committed
Make fast mode as an option in cfg
1 parent 21a8bee commit f45ba24

File tree

9 files changed

+36
-20
lines changed

9 files changed

+36
-20
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ __pycache__/
99

1010
pretrained/
1111

12-
output/
12+
output*/
1313

1414
.vscode/
1515

configs/Dynamic/Wild_SLAM_Mocap/basketball.yaml renamed to configs/Dynamic/Wild_SLAM_Mocap/ball.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ inherit_from: ./configs/Dynamic/Wild_SLAM_Mocap/wild_slam_mocap.yaml
22
scene: our_basketball
33

44
data:
5-
input_folder: ROOT_FOLDER_PLACEHOLDER/scene1/basketball
5+
input_folder: ROOT_FOLDER_PLACEHOLDER/scene1/ball

configs/Dynamic/Wild_SLAM_Mocap/person_tracking2.yaml renamed to configs/Dynamic/Wild_SLAM_Mocap/person_tracking.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
inherit_from: ./configs/Dynamic/Wild_SLAM_Mocap/wild_slam_mocap.yaml
2-
scene: person_tracking2
2+
scene: person_tracking
33

44
data:
5-
input_folder: ROOT_FOLDER_PLACEHOLDER/scene1/person_tracking2
5+
input_folder: ROOT_FOLDER_PLACEHOLDER/scene1/person_tracking
66

77
cam:
88
fx: 647.5684814453125

configs/wildgs_slam.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ gui: False
33
stride: 1 # use every X image from the dataset
44
max_frames: -1 # use the first X images from the dataset, -1 means using all
55
setup_seed: 43
6+
fast_mode: False
67
device: "cuda:0"
78

89
mapping:
910
online_plotting: False # render and save images online
1011
full_resolution: False # if using the full resolution for mapping, but we always keep downsampled size for tracking
11-
final_refine_iters: 3000 # iterations of final refinement
12-
eval_before_final_ba: True
12+
final_refine_iters: 20000 # iterations of final refinement, it will be forced to be 3000 if fast_mode is on
13+
eval_before_final_ba: False
1314
deform_gaussians: True # apply transformation on Gaussians to account for loop closure and BA
1415
pcd_downsample: 32 # downsamples the unprojected depth map --> point cloud
1516
pcd_downsample_init: 16 # first frame downsampling factor is smaller
@@ -140,7 +141,7 @@ cam:
140141
mono_prior:
141142
# Metric depth model, only support:
142143
# metric3d_vit_small, metric3d_vit_large and metric3d_vit_giant2
143-
# dpt2_{vits,vitb,vitl}_{hypersim,vkitti}_{20,80} (see src/mono_estimator.py for detail)
144+
# dpt2_{vits,vitb,vitl}_{hypersim,vkitti}_{20,80} (see src/utils/mono_priors/metric_depth_estimators.py for detail)
144145
# e.g. dpt2_vitl_hypersim_20, dpt2_vitl_vkitti_80
145146
depth: 'metric3d_vit_large'
146147

run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def setup_seed(seed):
2626

2727
cfg = config.load_config(args.config)
2828
setup_seed(cfg['setup_seed'])
29+
if cfg['fast_mode']:
30+
# Force the final refine iterations to be 3000 if in fast mode
31+
cfg['mapping']['final_refine_iters'] = 3000
2932

3033
output_dir = cfg['data']['output']
3134
output_dir = output_dir+f"/{cfg['scene']}"

src/frontend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ def __update(self, force_to_add_keyframe):
6161
for itr in range(self.iters1):
6262
self.graph.update(None, None, use_inactive=True)
6363

64-
# if itr == 1 and self.video.metric_depth_reg and self.cfg['tracking']["uncertainty_params"]['activate']:
65-
# self.video.filter_high_err_mono_depth(self.t1-1,self.graph.ii,self.graph.jj)
64+
if not self.cfg['fast_mode']:
65+
if itr == 1 and self.video.metric_depth_reg and self.cfg['tracking']["uncertainty_params"]['activate']:
66+
self.video.filter_high_err_mono_depth(self.t1-1,self.graph.ii,self.graph.jj)
6667

6768
d = self.video.distance([self.t1-2], [self.t1-1], beta=self.beta, bidirectional=True)
6869
# Ssee self.max_consecutive_drop_of_keyframes in initi for explanation of the following process

src/mapper.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,20 @@ def run(self):
241241
self.keyframe_optimizers = torch.optim.Adam(opt_params)
242242

243243
with Lock():
244-
if video_idx % 4 == 0:
244+
if self.config['fast_mode']:
245+
# We are in fast mode,
246+
# update map and uncertainty MLP every 4 key frames
247+
if video_idx % 4 == 0:
248+
gaussian_split = self.map_opt_online(
249+
self.current_window, iters=self.mapping_itr_num
250+
)
251+
else:
252+
self._update_occ_aware_visibility(self.current_window)
253+
else:
245254
gaussian_split = self.map_opt_online(
246255
self.current_window, iters=self.mapping_itr_num
247256
)
248-
else:
249-
self._update_occ_aware_visibility(self.current_window)
257+
250258
if gaussian_split:
251259
# do one more iteration after densify and prune
252260
self.map_opt_online(self.current_window, iters=1)

src/slam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def terminate(self):
209209
self.stream,
210210
self.logger,
211211
self.printer,
212+
self.cfg['fast_mode'],
212213
)
213214

214215
self.mapper.gaussians.save_ply(f"{self.save_dir}/final_gs.ply")

src/utils/eval_traj.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,20 @@ def kf_traj_eval(npz_path, plot_parent_dir,plot_name, stream, logger,printer):
140140
return ape_statistics, s, r_a, t_a
141141

142142

143-
def full_traj_eval(traj_filler, mapper, plot_parent_dir, plot_name, stream,logger,printer):
143+
def full_traj_eval(traj_filler, mapper, plot_parent_dir, plot_name, stream, logger, printer, fast_mode=False):
144144
traj_est_inv, dino_feats = traj_filler(stream)
145145
traj_est_lietorch = traj_est_inv.inv()
146146
traj_est = traj_est_lietorch.matrix().data.cpu().numpy()
147147

148-
# ## refine non-keyframe-traj from the mapping
149-
# for i in tqdm(range(traj_est.shape[0])):
150-
# img_feat = dino_feats[i]
151-
# w2c_refined = mapper.refine_pose_non_key_frame(i,
152-
# torch.tensor(np.linalg.inv(traj_est[i])),
153-
# features=img_feat)
154-
# traj_est[i] = np.linalg.inv(w2c_refined.cpu().numpy())
148+
if not fast_mode:
149+
# refine non-keyframe-traj from the mapping
150+
# this is time-consuming with minimal tracking improvement
151+
for i in tqdm(range(traj_est.shape[0])):
152+
img_feat = dino_feats[i]
153+
w2c_refined = mapper.refine_pose_non_key_frame(i,
154+
torch.tensor(np.linalg.inv(traj_est[i])),
155+
features=img_feat)
156+
traj_est[i] = np.linalg.inv(w2c_refined.cpu().numpy())
155157

156158
kf_num = traj_filler.video.counter.value
157159
kf_timestamps = traj_filler.video.timestamp[:kf_num].cpu().int().numpy()

0 commit comments

Comments
 (0)