Skip to content

Commit 16d954d

Browse files
aaroeytensorflower-gardener
authored andcommitted
Use TrtNodeValidator to determine whether an op type is supported, and register
all node converters as validators as well. This also fix a bug where it says Floor is supported (since it's in the op list in IsTensorRTCandidate(), and validator_.ValidateNode() returns OK since no validator was registered for it), but actually it's not supported until TRT 5.1. PiperOrigin-RevId: 235295283
1 parent 8af58fd commit 16d954d

File tree

4 files changed

+119
-159
lines changed

4 files changed

+119
-159
lines changed

tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc

Lines changed: 1 addition & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -87,94 +87,6 @@ TrtCandidateSelector::TrtCandidateSelector(
8787
: graph_properties_(graph_properties), precision_mode_(precision_mode) {}
8888

8989
Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
90-
// TODO(laigd): move this set to TrtNodeValidator where it should belong.
91-
// LINT.IfChange
92-
static const auto* candidate_ops = new std::set<string>{
93-
"Abs",
94-
"Acos",
95-
"Acosh",
96-
"Add",
97-
"Asin",
98-
"Asinh",
99-
"Atan",
100-
"Atanh",
101-
"AvgPool",
102-
"BatchMatMul",
103-
"BiasAdd",
104-
"Ceil",
105-
"ConcatV2",
106-
"Const",
107-
"Conv2D",
108-
"Conv2DBackpropInput",
109-
"Cos",
110-
"Cosh",
111-
"DepthwiseConv2dNative",
112-
"Div",
113-
"Exp",
114-
"ExpandDims",
115-
"Floor",
116-
"FusedBatchNorm",
117-
"FusedBatchNormV2",
118-
"GatherV2",
119-
"Identity",
120-
"LeakyRelu",
121-
"Log",
122-
"MatMul",
123-
"Max",
124-
"Maximum",
125-
"MaxPool",
126-
"Mean",
127-
"Min",
128-
"Minimum",
129-
"Mul",
130-
"Neg",
131-
"Pad",
132-
"Prod",
133-
"RealDiv",
134-
"Reciprocal",
135-
"Relu",
136-
"Relu6",
137-
"Reshape",
138-
"Rsqrt",
139-
"Sigmoid",
140-
"Sin",
141-
"Sinh",
142-
"Slice",
143-
"Snapshot",
144-
"Softmax",
145-
"Sqrt",
146-
"Square",
147-
"Squeeze",
148-
"StridedSlice",
149-
"Sub",
150-
"Sum",
151-
"Tan",
152-
"Tanh",
153-
"TopKV2",
154-
"Transpose",
155-
};
156-
bool is_supported_op_type =
157-
(candidate_ops->count(node->type_string()) ||
158-
PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
159-
static const auto* quantize_ops = new std::set<string>{
160-
"QuantizeAndDequantizeV2",
161-
"QuantizeAndDequantizeV3",
162-
"FakeQuantWithMinMaxVars",
163-
"FakeQuantWithMinMaxArgs",
164-
};
165-
// In INT8 mode, we will always apply the quantization ranges provided by
166-
// these ops to the relevant tensors. This happens regardless of the value of
167-
// use_calibration.
168-
if (precision_mode_ == TrtPrecisionMode::INT8 &&
169-
quantize_ops->count(node->type_string())) {
170-
is_supported_op_type = true;
171-
}
172-
// LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc)
173-
if (!is_supported_op_type) {
174-
return errors::Unimplemented("Op type ", node->type_string(),
175-
" is not supported");
176-
}
177-
17890
std::vector<const Edge*> input_edges;
17991
TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
18092
std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
@@ -184,7 +96,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
18496
input_edge->src_output());
18597
}
18698
return validator_.ValidateNode(node->def(), input_node_and_ports,
187-
graph_properties_);
99+
precision_mode_, graph_properties_);
188100
}
189101

190102
namespace {

0 commit comments

Comments
 (0)