1+ import torch
2+ import argparse
3+ import onnx
4+ import onnxruntime
5+ import json
6+ import numpy as np
7+ import cv2
8+
9+ from dpt .models import DPTDepthModel
10+ from dpt .midas_net import MidasNet_large
11+ import util .io
12+
13+
def main(model_path, model_type, output_path, batch_size, test_image_path):
    """Export a DPT/MiDaS monodepth model to ONNX and optionally validate it.

    Args:
        model_path: Path to the PyTorch model weights.
        model_type: One of ``dpt_large``, ``dpt_hybrid``, ``dpt_hybrid_kitti``,
            ``dpt_hybrid_nyu`` or ``midas_v21``.
        output_path: Destination path of the exported ``.onnx`` file.
        batch_size: Batch size used for tracing the model (must be an int).
        test_image_path: Optional path to an image; when given, the exported
            model's output is compared against the original PyTorch model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    # load network
    if model_type == "dpt_large":  # DPT-Large
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid_kitti":
        # KITTI model works on non-square 1216x352 crops.
        net_w = 1216
        net_h = 352

        model = DPTDepthModel(
            path=model_path,
            scale=0.00006016,
            shift=0.00579,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 256
    elif model_type == "dpt_hybrid_nyu":
        # NYU model works on non-square 640x480 inputs.
        net_w = 640
        net_h = 480

        model = DPTDepthModel(
            path=model_path,
            scale=0.000305,
            shift=0.1378,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1000.0
    elif model_type == "midas_v21":  # Convolutional model
        net_w = net_h = 384

        model = MidasNet_large(model_path, non_negative=True)
        normalization = dict(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        prediction_factor = 1
    else:
        # Raise instead of `assert False`: asserts are stripped under -O.
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type "
            "[dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]"
        )

    model.eval()

    # NCHW dummy input used for tracing the graph.
    dummy_input = torch.zeros((batch_size, 3, net_h, net_w))
    # TODO: right now, the batch size is not dynamic due to the PyTorch tracer
    # treating the batch size as constant (see get_attention() in vit.py).
    # Therefore you have to use a batch size of one to use this together with
    # run_monodepth_onnx.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # store normalization configuration as ONNX metadata so consumers can
    # reproduce the exact preprocessing used here
    model_onnx = onnx.load(output_path)
    meta_imagesize = model_onnx.metadata_props.add()
    meta_imagesize.key = "ImageSize"
    meta_imagesize.value = json.dumps([net_w, net_h])
    meta_normalization = model_onnx.metadata_props.add()
    meta_normalization.key = "Normalization"
    meta_normalization.value = json.dumps(normalization)
    meta_prediction_factor = model_onnx.metadata_props.add()
    meta_prediction_factor.key = "PredictionFactor"
    meta_prediction_factor.value = str(prediction_factor)
    onnx.save(model_onnx, output_path)
    del model_onnx

    if test_image_path is not None:
        # load test image
        img = util.io.read_image(test_image_path)

        # resize; cv2.resize takes dsize as (width, height) — passing
        # (net_h, net_w) transposes non-square inputs (kitti/nyu).  The
        # interpolation flag must also be a keyword argument: the third
        # positional parameter of cv2.resize is `dst`, not `interpolation`.
        img_input = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_AREA)

        # normalize with the same mean/std stored in the ONNX metadata
        img_input = (img_input - np.array(normalization["mean"])) / np.array(
            normalization["std"]
        )

        # transpose from HWC to CHW
        img_input = img_input.transpose(2, 0, 1)

        # add batch dimension (replicate the image batch_size times)
        img_input = np.stack([img_input] * batch_size)

        # validate accuracy of exported model
        torch_out = (
            model(torch.from_numpy(img_input.astype(np.float32)))
            .detach()
            .cpu()
            .numpy()
        )
        session = onnxruntime.InferenceSession(
            output_path,
            providers=[
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0]

        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04)
        print("Exported model predictions match original")
143+
144+
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the export.
    parser = argparse.ArgumentParser()
    parser.add_argument("model_weights", help="path to input model weights")
    parser.add_argument("output_path", help="path to output model weights")
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        # keep the help text in sync with the model types main() dispatches on
        help="model type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]",
    )
    # argparse yields strings by default; torch.zeros() needs an integer
    # shape, so the batch size must be parsed with type=int.
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size used for tracing",
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        help="path to some image to test the accuracy of the exported model against the original",
    )

    args = parser.parse_args()
    main(
        args.model_weights,
        args.model_type,
        args.output_path,
        args.batch_size,
        args.test_image_path,
    )