-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathbuild_engine_tool.cpp
More file actions
152 lines (128 loc) · 5.54 KB
/
build_engine_tool.cpp
File metadata and controls
152 lines (128 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <iostream>
#include <fstream>
#include <memory>
#include <chrono>
#include <NvInfer.h>
#include <NvOnnxParser.h>
/// TensorRT logger for this command-line tool.
///
/// Only messages at kWARNING or more severe are printed (kINFO/kVERBOSE are
/// dropped). Error-level messages (kERROR, kINTERNAL_ERROR) go to std::cerr,
/// like every other failure this tool reports; warnings go to std::cout.
/// Each line is prefixed with "[TensorRT] ".
class Logger final : public nvinfer1::ILogger {
public:
    void log(Severity severity, const char* msg) noexcept override {
        if (severity > Severity::kWARNING) {
            return;  // informational / verbose output is too noisy for a CLI
        }
        // Streaming a null const char* is undefined behaviour, so guard it.
        const char* text = (msg != nullptr) ? msg : "";
        std::ostream& out = (severity <= Severity::kERROR) ? std::cerr : std::cout;
        out << "[TensorRT] " << text << std::endl;
    }
};
bool buildEngine(const std::string& onnxPath, const std::string& enginePath, bool enableFP16 = false) {
Logger logger;
// Create builder
auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
if (!builder) {
std::cerr << "Failed to create TensorRT builder" << std::endl;
return false;
}
// Create network
auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(
builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
if (!network) {
std::cerr << "Failed to create network definition" << std::endl;
return false;
}
// Create ONNX parser
auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));
if (!parser) {
std::cerr << "Failed to create ONNX parser" << std::endl;
return false;
}
// Parse ONNX file
std::cout << "Parsing ONNX model: " << onnxPath << std::endl;
if (!parser->parseFromFile(onnxPath.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kWARNING))) {
std::cerr << "Failed to parse ONNX file" << std::endl;
return false;
}
// Create builder config
auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
if (!config) {
std::cerr << "Failed to create builder config" << std::endl;
return false;
}
// Set memory pool limit
config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1U << 30); // 1GB
// Enable FP16 if requested
if (enableFP16) {
config->setFlag(nvinfer1::BuilderFlag::kFP16);
std::cout << "FP16 precision enabled" << std::endl;
} else {
std::cout << "Using FP32 precision" << std::endl;
}
// Build engine
std::cout << "Building TensorRT engine (this may take several minutes)..." << std::endl;
auto start = std::chrono::high_resolution_clock::now();
auto engine = std::unique_ptr<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network, *config));
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
if (!engine) {
std::cerr << "Failed to build TensorRT engine" << std::endl;
return false;
}
std::cout << "Engine built successfully in " << duration.count() << " seconds" << std::endl;
// Print engine info
std::cout << "Engine info:" << std::endl;
std::cout << " Number of I/O tensors: " << engine->getNbIOTensors() << std::endl;
for (int i = 0; i < engine->getNbIOTensors(); ++i) {
const char* tensorName = engine->getIOTensorName(i);
auto dims = engine->getTensorShape(tensorName);
std::cout << " Tensor " << i << ": " << tensorName << " [";
for (int j = 0; j < dims.nbDims; ++j) {
std::cout << dims.d[j];
if (j < dims.nbDims - 1) std::cout << ", ";
}
std::cout << "]" << std::endl;
}
// Serialize and save engine
std::cout << "Saving engine to: " << enginePath << std::endl;
auto serializedEngine = std::unique_ptr<nvinfer1::IHostMemory>(engine->serialize());
if (!serializedEngine) {
std::cerr << "Failed to serialize engine" << std::endl;
return false;
}
std::ofstream file(enginePath, std::ios::binary);
if (!file.good()) {
std::cerr << "Failed to create engine file: " << enginePath << std::endl;
return false;
}
file.write(static_cast<const char*>(serializedEngine->data()), serializedEngine->size());
file.close();
std::cout << "Engine saved successfully (" << serializedEngine->size() / (1024*1024) << " MB)" << std::endl;
return true;
}
/// Command-line entry point: <program> <onnx_path> <engine_path> [fp16]
///
/// Returns 0 when the engine was built and saved. Returns 1 on missing
/// arguments, an unreadable ONNX file, or a build failure.
int main(int argc, char* argv[]) {
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " <onnx_path> <engine_path> [fp16]" << std::endl;
        std::cout << "Example: " << argv[0] << " models/yolov5s.onnx models/yolov5s.engine" << std::endl;
        std::cout << "Example: " << argv[0] << " models/yolov5s.onnx models/yolov5s.engine fp16" << std::endl;
        return 1;
    }

    const std::string onnxPath = argv[1];
    const std::string enginePath = argv[2];

    // Only the exact token "fp16" enables half precision. Warn on anything else
    // rather than silently building an FP32 engine (e.g. after a typo such as "FP16").
    bool enableFP16 = false;
    if (argc > 3) {
        const std::string option = argv[3];
        if (option == "fp16") {
            enableFP16 = true;
        } else {
            std::cerr << "Warning: ignoring unrecognized option '" << option
                      << "' (expected 'fp16'); building FP32 engine" << std::endl;
        }
    }
    if (argc > 4) {
        std::cerr << "Warning: ignoring " << (argc - 4) << " extra argument(s)" << std::endl;
    }

    std::cout << "=== TensorRT Engine Builder ===" << std::endl;
    std::cout << "ONNX model: " << onnxPath << std::endl;
    std::cout << "Engine output: " << enginePath << std::endl;
    std::cout << "Precision: " << (enableFP16 ? "FP16" : "FP32") << std::endl;
    std::cout << "================================" << std::endl;

    // Fail fast on an unreadable model before starting the (slow) build.
    {
        std::ifstream onnxFile(onnxPath);
        if (!onnxFile.good()) {
            std::cerr << "ONNX file not found: " << onnxPath << std::endl;
            return 1;
        }
    }

    if (!buildEngine(onnxPath, enginePath, enableFP16)) {
        std::cerr << "\n✗ Engine build failed!" << std::endl;
        return 1;
    }
    std::cout << "\n✓ Engine build completed successfully!" << std::endl;
    std::cout << "Engine location: " << enginePath << std::endl;
    return 0;
}