Skip to content

Commit 83b272b

Browse files
committed
custom training
1 parent a804fac commit 83b272b

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,27 @@ python sunrgbd/eval.py | tee sunrgbd/map.txt
192192
</details>
193193

194194
# Train on Your Own Object Collections
195-
Coming soon.
195+
196+
<details>
197+
<summary><b>Configuration Explained</b></summary>
198+
199+
To train on custom objects, it is necessary to understand some parameters in configuration files.
200+
- **up_sym**: Whether the objects look like a cylinder from up to bottom (e.g., bottles). This is to ensure the voting target is unambiguous.
201+
- **right_sym**: Whether the objects look like a cylinder from left to right (e.g., rolls). This is to ensure the voting target is unambiguous.
202+
- **regress_right**: Whether to predict the right axis. Some symmetric objects only have a up axis well defined (e.g., bowls, bottles), while some do not (e.g., laptops, mugs).
203+
- **z_right**: Whether the objects are placed such that the right axis is [0, 0, 1] (default: [1, 0, 0]).
204+
</details>
205+
<details>
206+
<summary><b>Voting Statistics Generation</b></summary>
207+
208+
Next, we need to know the ``scale_range`` (used for data augmentation, control possible object scales along the diagonal), ``vote_range`` (the range for center voting targets $\mu,\nu$), and ``scale_mean`` (the average 3D scale, used for scale voting). To generate them, you may refer to ``gen_stats.py``.
209+
</details>
210+
211+
<details>
212+
<summary><b>Write Configuration Files and Train</b></summary>
213+
214+
After you prepare the necessary configurations and voting statistics, you can write your own configuration file similar to that in ``config/category``, and then run ``train.py``.
215+
</details>
196216

197217
# Citation
198218
```

gen_stats.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import argparse
2+
from utils.dataset import generate_target
3+
from utils.util import typename2shapenetid
4+
import os
5+
import open3d as o3d
6+
import numpy as np
7+
from tqdm import tqdm
8+
9+
from utils.util import estimate_normals
10+
11+
12+
if __name__ == '__main__':
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument('--category', default='bowl', help='Category name')
15+
parser.add_argument('--shapenet_root', default='/home/neil/disk/ShapeNetCore.v2', help='ShapeNet root directory')
16+
parser.add_argument('--up_sym', action='store_true', help='If the objects look similar to a cylinder from up ([0, 1, 0]) to bottom')
17+
parser.add_argument('--right_sym', action='store_true', help='If the objects look similar to a cylinder from left to right')
18+
parser.add_argument('--z_right', action='store_true', help='If the objects use [0, 0, 1] as the right axis (default, [1, 0, 0])')
19+
args = parser.parse_args()
20+
21+
name_path = 'data/shapenet_names/{}.txt'.format(args.category)
22+
if os.path.exists(name_path):
23+
shapenames = open(name_path).read().splitlines()
24+
else:
25+
shapenet_id = typename2shapenetid[args.category]
26+
shapenames = os.listdir(os.path.join(args.shapenet_root, '{}'.format(shapenet_id)))
27+
shapenames = [shapenet_id + '/' + name for name in shapenames]
28+
29+
scale_range = [np.inf, -np.inf]
30+
vote_range = [0, 0]
31+
scale_mean = []
32+
for model_name in tqdm(shapenames):
33+
shapenet_cls, mesh_name = model_name.split('/')
34+
path = os.path.join(args.shapenet_root, f'{shapenet_cls}/{mesh_name}/models/model_normalized.obj')
35+
mesh = o3d.io.read_triangle_mesh(path)
36+
pc = np.array(mesh.sample_points_uniformly(2048).points)
37+
38+
# normalize to center
39+
pc -= (np.min(pc, 0) + np.max(pc, 0)) / 2
40+
41+
normals = estimate_normals(pc, 60)
42+
targets_tr = generate_target(pc, normals, args.up_sym, args.right_sym, args.z_right, 100000)[0]
43+
44+
diag_length = np.linalg.norm(np.max(pc, 0) - np.min(pc, 0))
45+
46+
scale_range[0] = min(scale_range[0], diag_length)
47+
scale_range[1] = max(scale_range[1], diag_length)
48+
49+
vote_range[0] = max(vote_range[0], np.max(np.abs(targets_tr[:, 0])))
50+
vote_range[1] = max(vote_range[1], np.max(targets_tr[:, 1]))
51+
52+
scale_mean.append(np.max(pc, 0))
53+
scale_mean = np.mean(scale_mean, 0)
54+
55+
print(f'scale_range: {scale_range}')
56+
print(f'vote_range: {vote_range}')
57+
print(f'scale_mean: {scale_mean}')
58+

0 commit comments

Comments
 (0)