Skip to content
This repository was archived by the owner on Dec 18, 2024. It is now read-only.

Commit 092dc8f

Browse files
committed
add ONNX conversion and runner scripts
1 parent 9fcf7ce commit 092dc8f

File tree

5 files changed

+499
-5
lines changed

5 files changed

+499
-5
lines changed

dpt/vit.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,17 @@ def forward(self, x):
318318
out_size = torch.Size((h // self.patch_size[1], w // self.patch_size[0]))
319319

320320
if not self.hybrid_backbone:
321-
layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size))
322-
layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size))
323-
324-
layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size))
325-
layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size))
321+
# according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114
322+
# layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size))
323+
# layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size))
324+
layer_1 = self.act_postprocess1(layer_1.view(layer_1.shape[0], layer_1.shape[1], *out_size))
325+
layer_2 = self.act_postprocess2(layer_2.view(layer_2.shape[0], layer_2.shape[1], *out_size))
326+
327+
# according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114
328+
# layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size))
329+
# layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size))
330+
layer_3 = self.act_postprocess3(layer_3.view(layer_3.shape[0], layer_3.shape[1], *out_size))
331+
layer_4 = self.act_postprocess4(layer_4.view(layer_4.shape[0], layer_4.shape[1], *out_size))
326332

327333
return layer_1, layer_2, layer_3, layer_4
328334

export_monodepth_onnx.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import torch
2+
import argparse
3+
import onnx
4+
import onnxruntime
5+
import json
6+
import numpy as np
7+
import cv2
8+
9+
from dpt.models import DPTDepthModel
10+
from dpt.midas_net import MidasNet_large
11+
import util.io
12+
13+
14+
def main(model_path, model_type, output_path, batch_size, test_image_path):
    """Export a DPT/MiDaS monodepth model to ONNX and optionally validate it.

    Loads the requested model variant, traces it to ONNX, embeds the input
    size / normalization / prediction-factor configuration as ONNX metadata,
    and — if ``test_image_path`` is given — compares the ONNX Runtime output
    against the original PyTorch model on that image.

    Args:
        model_path: path to the PyTorch model weights.
        model_type: one of ``dpt_large`` | ``dpt_hybrid`` |
            ``dpt_hybrid_kitti`` | ``dpt_hybrid_nyu`` | ``midas_v21``.
        output_path: destination path for the exported ``.onnx`` file.
        batch_size: batch size used for tracing (must be 1 for now, see the
            TODO below).
        test_image_path: optional image used to check the exported model's
            accuracy against the original; ``None`` skips validation.

    Raises:
        ValueError: if ``model_type`` is not one of the supported variants.
    """
    # load network
    if model_type == "dpt_large":  # DPT-Large
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid_kitti":
        net_w = 1216
        net_h = 352

        model = DPTDepthModel(
            path=model_path,
            scale=0.00006016,
            shift=0.00579,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 256
    elif model_type == "dpt_hybrid_nyu":
        net_w = 640
        net_h = 480

        model = DPTDepthModel(
            path=model_path,
            scale=0.000305,
            shift=0.1378,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1000.0
    elif model_type == "midas_v21":  # Convolutional model
        net_w = net_h = 384

        model = MidasNet_large(model_path, non_negative=True)
        normalization = dict(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        prediction_factor = 1
    else:
        # `assert False` is stripped under `python -O`; raise explicitly so
        # the error survives optimized runs.
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]"
        )

    model.eval()

    dummy_input = torch.zeros((batch_size, 3, net_h, net_w))
    # TODO: right now, the batch size is not dynamic due to the PyTorch tracer
    # treating the batch size as constant (see get_attention() in vit.py).
    # Therefore you have to use a batch size of one to use this together with
    # run_monodepth_onnx.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # store normalization configuration as ONNX metadata so runners don't
    # have to hard-code per-model preprocessing
    model_onnx = onnx.load(output_path)
    meta_imagesize = model_onnx.metadata_props.add()
    meta_imagesize.key = "ImageSize"
    meta_imagesize.value = json.dumps([net_w, net_h])
    meta_normalization = model_onnx.metadata_props.add()
    meta_normalization.key = "Normalization"
    meta_normalization.value = json.dumps(normalization)
    meta_prediction_factor = model_onnx.metadata_props.add()
    meta_prediction_factor.key = "PredictionFactor"
    meta_prediction_factor.value = str(prediction_factor)
    onnx.save(model_onnx, output_path)
    del model_onnx

    if test_image_path is not None:
        # load test image
        img = util.io.read_image(test_image_path)

        # resize; cv2.resize expects dsize as (width, height), and the
        # interpolation flag must be passed by keyword — the third
        # positional argument of cv2.resize is `dst`, not `interpolation`
        img_input = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_AREA)

        # normalize
        img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"])

        # transpose from HWC to CHW
        img_input = img_input.transpose(2, 0, 1)

        # add batch dimension
        img_input = np.stack([img_input] * batch_size)

        # validate accuracy of exported model
        torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy()
        session = onnxruntime.InferenceSession(
            output_path,
            providers=[
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0]

        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04)
        print("Exported model predictions match original")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_weights", help="path to input model weights")
    parser.add_argument("output_path", help="path to output model weights")
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        # list all variants main() actually supports, not just three of them
        help="model type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]",
    )
    # type=int is required: without it argparse yields a string, which breaks
    # torch.zeros((batch_size, ...)) and np.stack([...] * batch_size) in main()
    parser.add_argument(
        "--batch_size", default=1, type=int, help="batch size used for tracing"
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        help="path to some image to test the accuracy of the exported model against the original",
    )

    args = parser.parse_args()
    main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path)

export_segmentation_onnx.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import torch
2+
import argparse
3+
import onnx
4+
import onnxruntime
5+
import json
6+
import numpy as np
7+
import cv2
8+
9+
from dpt.models import DPTSegmentationModel
10+
import util.io
11+
12+
13+
def main(model_path, model_type, output_path, batch_size, test_image_path):
    """Export a DPT segmentation model to ONNX and optionally validate it.

    Loads the requested DPT segmentation variant (150 ADE20K classes),
    traces it to ONNX, embeds the input size and normalization configuration
    as ONNX metadata, and — if ``test_image_path`` is given — compares the
    ONNX Runtime output against the original PyTorch model on that image.

    Args:
        model_path: path to the PyTorch model weights.
        model_type: one of ``dpt_large`` | ``dpt_hybrid``.
        output_path: destination path for the exported ``.onnx`` file.
        batch_size: batch size used for tracing (must be 1 for now, see the
            TODO below).
        test_image_path: optional image used to check the exported model's
            accuracy against the original; ``None`` skips validation.

    Raises:
        ValueError: if ``model_type`` is not one of the supported variants.
    """
    net_w = net_h = 480
    normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # load network
    if model_type == "dpt_large":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitl16_384",
        )
    elif model_type == "dpt_hybrid":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitb_rn50_384",
        )
    else:
        # `assert False` is stripped under `python -O`; raise explicitly so
        # the error survives optimized runs.
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"
        )

    model.eval()

    dummy_input = torch.zeros((batch_size, 3, net_h, net_w))
    # TODO: right now, the batch size is not dynamic due to the PyTorch tracer
    # treating the batch size as constant (see get_attention() in vit.py).
    # Therefore you have to use a batch size of one to use this together with
    # run_monodepth_onnx.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # store normalization configuration as ONNX metadata so runners don't
    # have to hard-code per-model preprocessing
    model_onnx = onnx.load(output_path)
    meta_imagesize = model_onnx.metadata_props.add()
    meta_imagesize.key = "ImageSize"
    meta_imagesize.value = json.dumps([net_w, net_h])
    meta_normalization = model_onnx.metadata_props.add()
    meta_normalization.key = "Normalization"
    meta_normalization.value = json.dumps(normalization)
    onnx.save(model_onnx, output_path)
    del model_onnx

    if test_image_path is not None:
        # load test image
        img = util.io.read_image(test_image_path)

        # resize; cv2.resize expects dsize as (width, height), and the
        # interpolation flag must be passed by keyword — the third
        # positional argument of cv2.resize is `dst`, not `interpolation`
        img_input = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_AREA)

        # normalize
        img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"])

        # transpose from HWC to CHW
        img_input = img_input.transpose(2, 0, 1)

        # add batch dimension
        img_input = np.stack([img_input] * batch_size)

        # validate accuracy of exported model
        torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy()
        session = onnxruntime.InferenceSession(
            output_path,
            providers=[
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0]

        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04)
        print("Exported model predictions match original")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_weights", help="path to input model weights")
    parser.add_argument("output_path", help="path to output model weights")
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        help="model type [dpt_large|dpt_hybrid]",
    )
    # type=int is required: without it argparse yields a string, which breaks
    # torch.zeros((batch_size, ...)) and np.stack([...] * batch_size) in main()
    parser.add_argument(
        "--batch_size", default=1, type=int, help="batch size used for tracing"
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        help="path to some image to test the accuracy of the exported model against the original",
    )

    args = parser.parse_args()
    main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path)

0 commit comments

Comments
 (0)