|
31 | 31 | if __name__ == "__main__": |
32 | 32 | parser = argparse.ArgumentParser() |
33 | 33 | parser.add_argument('--seg_dir', default='data/nocs_seg', help='Segmentation PKL files for NOCS') |
| 34 | + parser.add_argument('--nocs_dir', default='data/nocs', help='NOCS real test image path') |
34 | 35 | parser.add_argument('--out_dir', default='data/nocs_prediction', help='Output directory for predictions') |
35 | 36 | parser.add_argument('--cp_device', type=int, default=0, help='GPU device number for custom voting algorithms') |
36 | 37 | parser.add_argument('--ckpt_path', default='checkpoints', help='Model checkpoint path') |
37 | 38 | parser.add_argument('--angle_prec', type=float, default=1.5, help='Angle precision in orientation voting') |
38 | 39 | parser.add_argument('--num_rots', type=int, default=72, help='Number of candidate center votes generated for a given point pair') |
| 40 | + parser.add_argument('--bbox_mask', action='store_true', help='Whether to use bbox mask instead of instance segmentations.') |
39 | 41 | args = parser.parse_args() |
40 | 42 |
|
41 | 43 | cp_device = args.cp_device |
|
102 | 104 | bcelogits = torch.nn.BCEWithLogitsLoss() |
103 | 105 |
|
104 | 106 | for res in tqdm(final_results): |
105 | | - img = cv2.imread(res['image_path'] + '_color.png')[:, :, ::-1] |
106 | | - depth = cv2.imread(res['image_path'] + '_depth.png', -1) |
| 107 | + img = cv2.imread(os.path.join(args.nocs_dir, res['image_path'][5:] + '_color.png'))[:, :, ::-1] |
| 108 | + depth = cv2.imread(os.path.join(args.nocs_dir, res['image_path'][5:] + '_depth.png'), -1) |
107 | 109 |
|
108 | 110 | bboxs = res['pred_bboxes'] |
109 | | - masks = res['pred_masks'] |
| 111 | + masks = res['pred_masks'].copy() |
110 | 112 | RTs = np.zeros((len(bboxs), 4, 4), dtype=np.float32) |
111 | 113 | scales = np.zeros((len(bboxs), 3), dtype=np.float32) |
112 | 114 | cls_ids = res['pred_class_ids'] |
113 | 115 |
|
114 | 116 | for i, bbox in enumerate(bboxs): |
| 117 | + if args.bbox_mask: |
| 118 | + masks[:, :, i][bbox[0]:bbox[2], bbox[1]:bbox[3]] = True |
| 119 | + |
115 | 120 | cls_id = cls_ids[i] |
116 | 121 | cls_name = synset_names[cls_id] |
117 | 122 |
|
|
297 | 302 | if cfg.regress_right: |
298 | 303 | right = final_directions[1] |
299 | 304 | right -= np.dot(up, right) * up |
300 | | - right /= np.linalg.norm(right) |
| 305 | + right /= (np.linalg.norm(right) + 1e-9) |
301 | 306 | else: |
302 | 307 | right = np.array([0, -up[2], up[1]]) |
303 | | - right /= np.linalg.norm(right) |
| 308 | + right /= (np.linalg.norm(right) + 1e-9) |
304 | 309 |
|
305 | 310 | if (cls_name == 'laptop') and (laptop_up is not None): |
306 | 311 | if np.dot(up, laptop_up) + np.dot(right, laptop_up) < np.dot(up, -laptop_up) + np.dot(right, -laptop_up): |
|
311 | 316 | right = up |
312 | 317 | up = laptop_up |
313 | 318 | right -= np.dot(up, right) * up |
314 | | - right /= np.linalg.norm(right) |
315 | | - |
| 319 | + right /= (np.linalg.norm(right) + 1e-9) |
| 320 | + |
| 321 | + if np.linalg.norm(right) < 1e-7: # right is zero |
| 322 | + right = np.random.randn(3) |
| 323 | + right -= right.dot(up) * up |
| 324 | + right /= np.linalg.norm(right) |
| 325 | + |
316 | 326 | if cfg.z_right: |
317 | 327 | R_est = np.stack([np.cross(up, right), up, right], -1) |
318 | 328 | else: |
319 | 329 | R_est = np.stack([right, up, np.cross(right, up)], -1) |
320 | 330 |
|
321 | 331 | pred_scale = np.exp(preds_scale[0].mean(0).cpu().numpy()) * cfg.scale_mean * 2 |
322 | 332 | scale_norm = np.linalg.norm(pred_scale) |
| 333 | + assert scale_norm > 0 |
323 | 334 | RTs[i][:3, :3] = R_est * scale_norm |
| 335 | + RTs[i][3, 3] = 1. |
324 | 336 | scales[i, :] = pred_scale / scale_norm |
325 | 337 |
|
326 | 338 | res['pred_RTs'] = RTs |
|
0 commit comments