@@ -87,94 +87,6 @@ TrtCandidateSelector::TrtCandidateSelector(
8787 : graph_properties_(graph_properties), precision_mode_(precision_mode) {}
8888
8989Status TrtCandidateSelector::IsTensorRTCandidate (const tensorflow::Node* node) {
90- // TODO(laigd): move this set to TrtNodeValidator where it should belong.
91- // LINT.IfChange
92- static const auto * candidate_ops = new std::set<string>{
93- " Abs" ,
94- " Acos" ,
95- " Acosh" ,
96- " Add" ,
97- " Asin" ,
98- " Asinh" ,
99- " Atan" ,
100- " Atanh" ,
101- " AvgPool" ,
102- " BatchMatMul" ,
103- " BiasAdd" ,
104- " Ceil" ,
105- " ConcatV2" ,
106- " Const" ,
107- " Conv2D" ,
108- " Conv2DBackpropInput" ,
109- " Cos" ,
110- " Cosh" ,
111- " DepthwiseConv2dNative" ,
112- " Div" ,
113- " Exp" ,
114- " ExpandDims" ,
115- " Floor" ,
116- " FusedBatchNorm" ,
117- " FusedBatchNormV2" ,
118- " GatherV2" ,
119- " Identity" ,
120- " LeakyRelu" ,
121- " Log" ,
122- " MatMul" ,
123- " Max" ,
124- " Maximum" ,
125- " MaxPool" ,
126- " Mean" ,
127- " Min" ,
128- " Minimum" ,
129- " Mul" ,
130- " Neg" ,
131- " Pad" ,
132- " Prod" ,
133- " RealDiv" ,
134- " Reciprocal" ,
135- " Relu" ,
136- " Relu6" ,
137- " Reshape" ,
138- " Rsqrt" ,
139- " Sigmoid" ,
140- " Sin" ,
141- " Sinh" ,
142- " Slice" ,
143- " Snapshot" ,
144- " Softmax" ,
145- " Sqrt" ,
146- " Square" ,
147- " Squeeze" ,
148- " StridedSlice" ,
149- " Sub" ,
150- " Sum" ,
151- " Tan" ,
152- " Tanh" ,
153- " TopKV2" ,
154- " Transpose" ,
155- };
156- bool is_supported_op_type =
157- (candidate_ops->count (node->type_string ()) ||
158- PluginFactoryTensorRT::GetInstance ()->IsPlugin (node->type_string ()));
159- static const auto * quantize_ops = new std::set<string>{
160- " QuantizeAndDequantizeV2" ,
161- " QuantizeAndDequantizeV3" ,
162- " FakeQuantWithMinMaxVars" ,
163- " FakeQuantWithMinMaxArgs" ,
164- };
165- // In INT8 mode, we will always apply the quantization ranges provided by
166- // these ops to the relevant tensors. This happens regardless of the value of
167- // use_calibration.
168- if (precision_mode_ == TrtPrecisionMode::INT8 &&
169- quantize_ops->count (node->type_string ())) {
170- is_supported_op_type = true ;
171- }
172- // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc)
173- if (!is_supported_op_type) {
174- return errors::Unimplemented (" Op type " , node->type_string (),
175- " is not supported" );
176- }
177-
17890 std::vector<const Edge*> input_edges;
17991 TF_RETURN_IF_ERROR (node->input_edges (&input_edges));
18092 std::vector<std::pair<const NodeDef*, int >> input_node_and_ports;
@@ -184,7 +96,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
18496 input_edge->src_output ());
18597 }
18698 return validator_.ValidateNode (node->def (), input_node_and_ports,
187- graph_properties_);
99+ precision_mode_, graph_properties_);
188100}
189101
190102namespace {
0 commit comments