xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
17 
18 #include <algorithm>
19 #include <bitset>
20 #include <cmath>
21 #include <cstring>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <unordered_map>
26 #include <utility>
27 #include <vector>
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/memory/memory.h"
32 #include "absl/strings/match.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_format.h"
35 #include "absl/strings/string_view.h"
36 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
37 #include "tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h"
38 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
39 #include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"
40 #include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
41 #include "tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.h"
42 #include "tensorflow/compiler/tf2tensorrt/convert/timing_cache.h"
43 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
44 #include "tensorflow/compiler/tf2tensorrt/utils/trt_experimental_features.h"
45 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
46 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
47 #include "tensorflow/core/common_runtime/graph_constructor.h"
48 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
49 #include "tensorflow/core/framework/node_def_builder.h"
50 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
51 #include "tensorflow/core/framework/tensor_shape.h"
52 #include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
53 #include "tensorflow/core/framework/tensor_util.h"
54 #include "tensorflow/core/framework/types.h"
55 #include "tensorflow/core/graph/algorithm.h"
56 #include "tensorflow/core/graph/graph.h"
57 #include "tensorflow/core/grappler/grappler_item.h"
58 #include "tensorflow/core/grappler/op_types.h"
59 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
60 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/lib/strings/numbers.h"
64 #include "tensorflow/core/lib/strings/str_util.h"
65 #include "tensorflow/core/lib/strings/strcat.h"
66 #include "tensorflow/core/platform/logging.h"
67 #include "tensorflow/core/platform/mutex.h"
68 #include "tensorflow/core/platform/protobuf.h"
69 #include "tensorflow/core/platform/tensor_coding.h"
70 #include "tensorflow/core/platform/tensor_float_32_utils.h"
71 #include "tensorflow/core/platform/types.h"
72 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
73 #include "tensorflow/core/public/version.h"
74 #include "tensorflow/core/util/env_var.h"
75 #include "tensorflow/core/util/strided_slice_op.h"
76 
77 #if GOOGLE_CUDA && GOOGLE_TENSORRT
78 #include "third_party/tensorrt/NvInfer.h"
79 #include "third_party/tensorrt/NvInferPlugin.h"
80 
81 // Check if the types are equal. Cast to int first so that the failure log
82 // message prints the values.
83 #define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
84 
85 #define TFTRT_CHECK_INPUT_SIZE(size, exp_size, node_def)                 \
86   if ((size) != (exp_size)) {                                            \
87     TFTRT_ERROR(errors::InvalidArgument, node_def.op(), " got ", (size), \
88                 " inputs but expected ", (exp_size));                    \
89   }
90 
91 namespace tensorflow {
92 namespace tensorrt {
93 namespace convert {
94 
95 using absl::StrAppend;
96 using absl::StrCat;
97 
98 namespace {
99 
100 #define ADD_LAYER(layer_name)              \
101   case nvinfer1::LayerType::k##layer_name: \
102     return #layer_name;
103 
104 const char* LayerTypeToString(nvinfer1::LayerType layer_type) {
105   switch (layer_type) {
106     ADD_LAYER(CONVOLUTION)
107     ADD_LAYER(FULLY_CONNECTED)
108     ADD_LAYER(ACTIVATION)
109     ADD_LAYER(POOLING)
110     ADD_LAYER(LRN)
111     ADD_LAYER(SCALE)
112     ADD_LAYER(SOFTMAX)
113     ADD_LAYER(DECONVOLUTION)
114     ADD_LAYER(CONCATENATION)
115     ADD_LAYER(ELEMENTWISE)
116     ADD_LAYER(PLUGIN)
117     ADD_LAYER(UNARY)
118     ADD_LAYER(PADDING)
119     ADD_LAYER(SHUFFLE)
120     ADD_LAYER(REDUCE)
121     ADD_LAYER(TOPK)
122     ADD_LAYER(GATHER)
123     ADD_LAYER(MATRIX_MULTIPLY)
124     ADD_LAYER(RAGGED_SOFTMAX)
125     ADD_LAYER(CONSTANT)
126     ADD_LAYER(RNN_V2)
127     ADD_LAYER(IDENTITY)
128     ADD_LAYER(PLUGIN_V2)
129     ADD_LAYER(SLICE)
130     ADD_LAYER(SHAPE)
131     ADD_LAYER(PARAMETRIC_RELU)
132     ADD_LAYER(RESIZE)
133     ADD_LAYER(TRIP_LIMIT)
134     ADD_LAYER(RECURRENCE)
135     ADD_LAYER(ITERATOR)
136     ADD_LAYER(LOOP_OUTPUT)
137     ADD_LAYER(SELECT)
138     ADD_LAYER(FILL)
139 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
140     ADD_LAYER(QUANTIZE)
141     ADD_LAYER(DEQUANTIZE)
142 #else
143     // The TRT IRNNv2Layer has been deprecated in favor of the loop API.
144     ADD_LAYER(RNN)
145 #endif
146   }
147   return "UNKNOWN_LAYER";
148 }
149 
150 #undef ADD_LAYER
151 
152 // Sets the ILayer name in the form of
153 // <engine_name>/<tf_related_part>:<trt_operation_name>.
154 void SetLayerNameHelper(nvinfer1::ILayer* layer, absl::string_view engine_name,
155                         absl::string_view tf_name) {
156   const char* trt_name = LayerTypeToString(layer->getType());
157   layer->setName(
158       absl::StrCat(engine_name, "/", tf_name, ":", trt_name).c_str());
159 }
160 
161 // Returns a string in the form of <sub_op_name><sub_op_instance>.
162 std::string GetLayerNameSuffix(absl::string_view sub_op_name,
163                                std::optional<int> sub_op_instance) {
164   std::string op_suffix(sub_op_name);
165   if (sub_op_instance.has_value()) {
166     op_suffix =
167         absl::StrCat(op_suffix, "_", std::to_string(sub_op_instance.value()));
168   }
169   return op_suffix;
170 }
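// For illustration (names below are hypothetical): GetLayerNameSuffix("slice", 2)
// yields "slice_2", and without a sub_op_instance it returns just "slice".
// Combined with SetLayerNameHelper above, a SHUFFLE layer created for TF node
// "model/add" inside engine "trt_engine_0" would end up named
// "trt_engine_0/model/add:SHUFFLE".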
171 
172 }  // namespace
173 
174 bool IsEngineInput(absl::string_view name) {
175   return absl::StartsWith(name, IONamePrefixes::kInputPHName);
176 }
177 bool IsEngineOutput(absl::string_view name) {
178   return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
179 }
180 
181 void GetOutputProperties(const grappler::GraphProperties& graph_properties,
182                          const Node* node, const int out_port,
183                          PartialTensorShape* shape, DataType* dtype) {
184   if (graph_properties.HasOutputProperties(node->name())) {
185     auto output_params = graph_properties.GetOutputProperties(node->name());
186     auto out_shape = output_params.at(out_port);
187     *dtype = out_shape.dtype();
188     *shape = out_shape.shape();
189   } else {
190     LOG(INFO) << "Unknown output shape for node " << node->name();
191     *dtype = node->output_type(out_port);
192   }
193 }
194 
195 void GetInputProperties(const grappler::GraphProperties& graph_properties,
196                         const Node* node, const int in_port,
197                         PartialTensorShape* shape, DataType* dtype) {
198   if (graph_properties.HasInputProperties(node->name())) {
199     auto input_params = graph_properties.GetInputProperties(node->name());
200     auto in_shape = input_params.at(in_port);
201     *dtype = in_shape.dtype();
202     *shape = in_shape.shape();
203   } else {
204     *dtype = node->input_type(in_port);
205   }
206 }
207 
208 // This function checks if a tensor is compatible with TRT.
209 //
210 // We check that the shape and datatype are compatible with TensorRT. We also
211 // return the corresponding trt_dtype, the trt_dims and the batch_size (latter
212 // is only needed in implicit batch mode).
213 //
214 // The return status indicates whether the tensor is compatible.
215 //
216 // For implicit batch mode, when validation_only == false, we also check that
217 // all input dimensions (besides the batch dimension) are known dimensions.
218 Status ValidateTensorProperties(const string& producer_node_type,
219                                 const DataType dtype,
220                                 const PartialTensorShape& shape,
221                                 const bool use_implicit_batch,
222                                 bool validation_only,
223                                 nvinfer1::DataType* trt_dtype,
224                                 nvinfer1::Dims* trt_dims, int* batch_size) {
225   // Convert data type.
226   TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, trt_dtype));
227 
228   // Convert shape.
229   if (shape.dims() < 0) {
230     return errors::InvalidArgument("Input tensor rank is unknown.");
231   }
232   // Add 1 to maximum rank for implicit batch dim.
233   const int max_rank = nvinfer1::Dims::MAX_DIMS + (use_implicit_batch ? 1 : 0);
234   if (shape.dims() > max_rank) {
235     return errors::OutOfRange("Input tensor rank is greater than ", max_rank);
236   }
237   if (use_implicit_batch && (producer_node_type != "Const") &&
238       (shape.dims() < 1)) {
239     return errors::InvalidArgument(
240         "Scalar input tensor is not supported since the first dimension "
241         "is treated as batch dimension by TRT");
242   }
243   StatusOr<DimsAdapter> dims = DimsAdapter::Create(shape, use_implicit_batch);
244   TRT_ENSURE_OK(dims);
245   *trt_dims = dims->AsTrtDims();
246   // Get the batch size for the tensor if it will not be included in the shape.
247   if (use_implicit_batch) {
248     *batch_size = shape.dim_size(0);
249   }
250 
251   // Don't convert empty tensors (dim value of 0).
252   const int first_trt_dim = use_implicit_batch ? 1 : 0;
253   for (int d = first_trt_dim; d < shape.dims(); ++d) {
254     if (shape.dim_size(d) == 0) {
255       return errors::Unimplemented(
256           "Input tensor with shape ", shape.DebugString(),
257           " is an empty tensor, which is not supported by TRT");
258     }
259   }
260 
261   if (validation_only) return Status::OK();
262 
263   // Following checks are only used during TRT engine creation time.
264   if (use_implicit_batch) {
265     for (int d = first_trt_dim; d < shape.dims(); ++d) {
266       if (shape.dim_size(d) < 0) {
267         return errors::InvalidArgument(
268             "Input tensor with shape ", shape.DebugString(),
269             " has an unknown non-batch dimension at dim ", d);
270       }
271     }
272   }
273   return Status::OK();
274 }
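// A rough illustration of the checks above: in implicit batch mode a tensor of
// shape [8, 224, 224, 3] produces trt_dims {224, 224, 3} with batch_size 8,
// whereas a shape like [8, -1, 3] passes validation-only mode but is rejected
// at engine build time because a non-batch dimension is unknown, and any shape
// containing a 0 dimension is rejected as an empty tensor.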
275 
276 Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
277                             const TRT_TensorOrWeights& operand_r,
278                             const bool check_feasibility,
279                             const bool use_implicit_batch,
280                             nvinfer1::Dims* operand_l_new_dims,
281                             nvinfer1::Dims* operand_r_new_dims) {
282   // The TensorRT elementwise op supports broadcasting but requires both
283   // tensors to be of identical rank.
284   //
285   // We consider case of:
286   //   1. operand_l to be a Tensor & operand_r to be a Const;
287   //   2. operand_l to be a Tensor & operand_r to be a Tensor;
288   // note: const op const (constant folding) should fallback to TensorFlow
289   //
290   // broadcast scheme:
291   //       T:  1 3 5    (tensor would not have batch dimension)
292   //       W:  1 1 3 1  (weight would have all explicit dimensions)
293   // i. fill in explicit dimensions
294   //    -> T: -1 1 3 5  (we put a -1 for batch dimension)
295   //    -> W:  1 1 3 1
296   // ii. compare broadcast feasibility
297   //
298   // We cannot support the following, since TensorRT does not allow manipulating
299   // the batch dimension and we cannot generate an output with the proper shape:
300   //    T: 3 5 1
301   //    W: 1 1 1  1 3 5 1
302   // -> T: 1 1 1 -1 3 5 1
303   // -> W: 1 1 1  1 3 5 1
304   // ***************************************************************************
305   if (!operand_l.is_tensor() && !operand_r.is_tensor()) {
306     // TODO(lsugy): remove this check in dynamic shapes mode. This should work
307     // if both inputs are weights.
308     return errors::InvalidArgument(
309         "Broadcasting requires at least one of the operands be tensors");
310   }
311 
312   constexpr int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
313   auto compute_output_dims =
314       [use_implicit_batch](const TRT_TensorOrWeights& input,
315                            int broadcast_num_dims,
316                            std::array<int32_t, max_nb_dims>* output_dims_array,
317                            nvinfer1::Dims* output_dims) -> Status {
318     const nvinfer1::Dims input_dims = input.GetTrtDims();
319     absl::c_fill(*output_dims_array, 1);
320     absl::c_copy(
321         DimsAdapter(input_dims),
322         output_dims_array->begin() + broadcast_num_dims - input_dims.nbDims);
323     if (use_implicit_batch && input.is_tensor()) {
324       const int true_input_dims = input_dims.nbDims + 1;
325       if (true_input_dims < broadcast_num_dims) {
326         return errors::InvalidArgument(
327             "Broadcasting beyond batch dimension is not supported ",
328             "(tensor #dims ", true_input_dims, " vs broadcast #dims ",
329             broadcast_num_dims, ")");
330       }
331       // Set the batch dimension to -1, since batch size is not supposed to
332       // be broadcasted.
333       (*output_dims_array)[0] = -1;
334     }
335     // Copy to output dimensions
336     auto offt = use_implicit_batch ? 1 : 0;
337     output_dims->nbDims = broadcast_num_dims - offt;
338     absl::c_copy(
339         absl::MakeSpan(*output_dims_array).subspan(offt, broadcast_num_dims),
340         output_dims->d);
341     return Status::OK();
342   };
343 
344   // Compute the output dimensions.
345   const int broadcast_num_dims =
346       std::max(operand_l.GetTrtDims().nbDims +
347                    (use_implicit_batch && operand_l.is_tensor()),
348                operand_r.GetTrtDims().nbDims +
349                    (use_implicit_batch && operand_r.is_tensor()));
350   std::array<int32_t, max_nb_dims> output_l, output_r;
351   TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims,
352                                          &output_l, operand_l_new_dims));
353   TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims,
354                                          &output_r, operand_r_new_dims));
355 
356   // Compare broadcast feasibility
357   if (check_feasibility) {
358     for (int i = 0; i < broadcast_num_dims; ++i) {
359       if (!use_implicit_batch && (output_l[i] == -1 || output_r[i] == -1)) {
360         // If the condition is true then we are in explicit batch mode and (at
361         // least) one of the input dimensions is unknown. In other words we
362         // are in dynamic shape mode. During conversion time we only see -1 for
363         // the unknown shapes, therefore we cannot decide on the feasibility of
364         // broadcast over the unknown dimensions. Therefore we just continue for
365         // the next dimension. In dynamic shape mode TRT can only check the
366         // feasibility of the broadcast when the actual input dimensions are
367         // specified by SetTrtEngineInputs and the inference job is launched by
368         // TrtEnqueue.
369         continue;
370       }
371       if ((output_l[i] != output_r[i]) && (output_l[i] != 1) &&
372           (output_r[i] != 1)) {
373         return errors::InvalidArgument("Infeasible broadcast scheme (",
374                                        "batch_dim: ", output_l[0], ", ",
375                                        DebugString(*operand_l_new_dims), " vs ",
376                                        "batch_dim: ", output_r[0], ", ",
377                                        DebugString(*operand_r_new_dims), ")");
378       }
379     }
380   }
381   return Status::OK();
382 }
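// A small worked example of the scheme above in implicit batch mode: a tensor
// with TRT dims [3, 5] (true rank 3 including the implicit batch) against a
// weight with dims [1, 1, 3, 1] gives broadcast_num_dims == 4, padded arrays
// T: [-1, 1, 3, 5] and W: [1, 1, 3, 1], and output dims (batch stripped) of
// [1, 3, 5] and [1, 3, 1]; this is feasible because every dimension pair is
// either equal or contains a 1.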
383 
384 // Prepares a dynamic shape tensor for broadcast by adding leading 1 dimensions.
385 Status DynamicBroadcast(ITensorProxyPtr operand, OpConverterParams* params,
386                         ITensorProxyPtr* output, int broadcasted_nbDims,
387                         std::optional<int> op_instance) {
388   int operand_nbDims = operand->getDimensions().nbDims;
389   if (broadcasted_nbDims > operand_nbDims) {
390     if (params->validation_only) return Status::OK();
391     int n_extra_dims = broadcasted_nbDims - operand_nbDims;
392     VLOG(2) << "Dynamic broadcast adding " << n_extra_dims << " leading 1s";
393     TF_RETURN_IF_ERROR(params->converter->DynamicReshape(
394         /*input=*/operand,
395         /*slices=*/{std::make_pair(0, operand_nbDims)},
396         /*params=*/params,
397         /*output=*/output,
398         /*size_for_added_dims*/ {n_extra_dims},
399         /*op_instance=*/op_instance));
400   } else {
401     *output = operand;
402   }
403   return Status::OK();
404 }
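// For example, broadcasting a dynamic-shape tensor of rank 2 up to
// broadcasted_nbDims == 4 goes through DynamicReshape, which prepends two
// leading dimensions of size 1; an operand whose rank already matches is
// returned unchanged.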
405 
406 Status BroadcastWeights(std::unique_ptr<TRT_TensorOrWeights>& p,
407                         const DimsAdapter& broadcasted_dims) {
408   if (!p->is_weights()) return errors::Internal("Weight input expected");
409   if (p->GetTrtDims().nbDims != broadcasted_dims.NumDims()) {
410     TRT_ShapedWeights weights(p->weights());
411     TF_RETURN_IF_ERROR(weights.SetShape(broadcasted_dims));
412     p = std::make_unique<TRT_TensorOrWeights>(weights);
413   }
414   return Status::OK();
415 }
416 
417 Status ApplyBroadcast(std::unique_ptr<TRT_TensorOrWeights>& operand,
418                       const DimsAdapter& broadcasted_dims,
419                       OpConverterParams* params,
420                       std::optional<int> op_instance) {
421   if (operand->is_weights()) {
422     TF_RETURN_IF_ERROR(BroadcastWeights(operand, broadcasted_dims));
423   } else {
424     ITensorProxyPtr tensor = nullptr;
425     auto is_static_shuffle_compatible = [](const auto& dims) {
426       return absl::c_count(dims, -1) <= 1;
427     };
428     if (is_static_shuffle_compatible(broadcasted_dims)) {
429       TF_RETURN_IF_ERROR(PrepareTensorForShape(
430           params->converter, *operand, broadcasted_dims,
431           params->validation_only, &tensor, params->node_def));
432     } else {
433       TF_RETURN_IF_ERROR(DynamicBroadcast(
434           /*operand=*/operand->tensor(),
435           /*params=*/params,
436           /*output=*/&tensor,
437           /*broadcasted_nbDims*/ broadcasted_dims.NumDims(),
438           /*op_instance=*/op_instance));
439     }
440     operand = std::make_unique<TRT_TensorOrWeights>(tensor);
441   }
442   return Status::OK();
443 }
444 
445 // Inserts leading 1 dimensions so that both operands have the same rank.
446 // Note: In implicit batch mode, weights' shape can include an explicit 1 batch
447 // dimension. The broadcasted shape might lose this leading batch dim, because
448 // the broadcasted shape does not include the implicit batch dim.
449 // TODO(tfeher): Other code blocks that use GetTrtBroadcastShape need to be
450 // fixed to use this routine to handle dynamic inputs. Eventually,
451 // GetTrtBroadcastShape should only be used by this routine.
452 Status BroadcastTensors(std::unique_ptr<TRT_TensorOrWeights>& operand_l,
453                         std::unique_ptr<TRT_TensorOrWeights>& operand_r,
454                         bool check_feasibility, OpConverterParams* params) {
455   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
456   TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
457       *operand_l, *operand_r, check_feasibility, params->use_implicit_batch,
458       &broadcasted_dims_l, &broadcasted_dims_r));
459 
460   if (params->validation_only) return Status::OK();
461 
462   TF_RETURN_IF_ERROR(ApplyBroadcast(
463       /*operand=*/operand_l,
464       /*broadcasted_dims=*/broadcasted_dims_l,
465       /*params=*/params,
466       /*op_instance=*/0));
467 
468   TF_RETURN_IF_ERROR(ApplyBroadcast(
469       /*operand=*/operand_r,
470       /*broadcasted_dims=*/broadcasted_dims_r,
471       /*params=*/params,
472       /*op_instance=*/1));
473 
474   return Status::OK();
475 }
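// A typical (hypothetical) use by a binary op converter: wrap both inputs in
// std::unique_ptr<TRT_TensorOrWeights>, call
//   BroadcastTensors(operand_l, operand_r, /*check_feasibility=*/true, params);
// and then feed the resulting equal-rank operands into an IElementWiseLayer,
// letting TensorRT handle the per-dimension broadcasting.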
476 
477 ITensorProxyPtr Converter::CreateConstantLayer(const TRT_ShapedWeights& weights,
478                                                const nvinfer1::Dims& dims) {
479   nvinfer1::Weights trt_weights = weights.GetTrtWeights();
480   nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights);
481   if (!layer) return nullptr;
482   SetLayerName(layer, "_tftrt_constant_",
483                std::to_string(next_constant_layer_id_));
484   next_constant_layer_id_++;
485   ITensorProxyPtr trt_tensor = layer->getOutput(0);
486   return trt_tensor;
487 }
488 
489 // Creates a scalar constant and fills with value.
490 template <typename T>
491 Status CreateScalarConstant(
492     OpConverterParams* params, T value, ITensorProxyPtr* tensor,
493     nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32,
494     const nvinfer1::Dims& dims = {1, {1}}) {
495   StatusOr<TRT_ShapedWeights> weights =
496       params->weight_store->GetTempWeights(trt_type, dims);
497   TRT_ENSURE_OK(weights);
498   TF_RETURN_IF_ERROR(weights->SetValues(value));
499   *tensor = params->converter->CreateConstantLayer(*weights, dims);
500   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
501   return Status::OK();
502 }
503 
504 // Creates a constant with the same rank as dims, where each dimension has
505 // size = 1.
506 Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
507                                          const nvinfer1::Dims& dims,
508                                          ITensorProxyPtr* tensor,
509                                          const char* dtype_attr_name = "T") {
510   nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT;  // Default to FP32.
511   AttrSlice attrs(params->node_def);
512   if (attrs.FindByString(dtype_attr_name) != nullptr) {
513     DataType dtype;
514     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, dtype_attr_name, &dtype));
515     TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type));
516   }
517 
518   // In order to be broadcastable, the number of dims has to match.
519   nvinfer1::Dims broadcastable_dims(dims);
520   for (int i = 0; i < broadcastable_dims.nbDims; i++) {
521     broadcastable_dims.d[i] = 1;
522   }
523   return CreateScalarConstant(params, value, tensor, trt_type,
524                               broadcastable_dims);
525 }
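// For instance, given dims of {3, {5, 28, 28}} (illustrative values) this
// creates a constant of shape {1, 1, 1} filled with `value`, which can then be
// broadcast against a rank-3 tensor by an elementwise layer.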
526 
527 // The function concatenates tensors on the first axis. This can be used to
528 // create a shape tensor from individual dimension sizes.
529 StatusOr<ITensorProxyPtr> ConcatenateTensors(
530     OpConverterParams* params, const std::vector<ITensorProxyPtr> input_tensors,
531     std::optional<int> op_instance = std::nullopt) {
532   std::vector<nvinfer1::ITensor*> trt_input_tensors;
533   for (const auto& t : input_tensors) {
534     trt_input_tensors.push_back(t->trt_tensor());
535   }
536   nvinfer1::IConcatenationLayer* layer =
537       params->converter->network()->addConcatenation(
538           static_cast<nvinfer1::ITensor* const*>(trt_input_tensors.data()),
539           input_tensors.size());
540   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.op());
541   params->converter->SetLayerName(layer, params->node_def.name(),
542                                   "concat_shapes", op_instance);
543   layer->setAxis(0);
544   return ITensorProxyPtr(layer->getOutput(0));
545 }
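// For example, concatenating shape-tensor fragments of lengths 2, 1 and 3 along
// axis 0 yields a 1-D shape tensor of length 6; this is how helpers such as
// DynamicReshape can assemble a full reshape specification from slices of an
// input's shape plus constant dimensions.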
546 
547 // Convert an axis from TF format to TRT format while validating. TF format
548 // includes the batch dimension, while TRT does not if implicit batching is used
549 // (i.e. for tensors). TF can also use negative indices.
550 Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
551                    bool use_implicit_batch, int* trt_axis) {
552   const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
553   // Check bounds.
554   if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
555     return errors::InvalidArgument(
556         "Axis value of ", tf_axis, " is out of bounds, must be in range [",
557         -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name);
558   }
559   // Make negative axis positive.
560   if (tf_axis < 0) tf_axis += tf_nb_dims;
561   // Don't allow axis to be the batch dimension.
562   if (use_implicit_batch && tf_axis == 0) {
563     return errors::Unimplemented(
564         "TensorRT does not allow manipulation of the batch dimension");
565   }
566   // Remove batch dimension if it is implicit.
567   *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
568   return Status::OK();
569 }
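// Worked example: with use_implicit_batch == true and a TF tensor of rank 4
// (trt_nb_dims == 3), tf_axis == -1 maps to TF axis 3 and therefore
// trt_axis == 2, while tf_axis == 0 is rejected because it addresses the
// implicit batch dimension.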
570 
571 bool AllLengthsEqual(const std::vector<std::vector<int>>& inputs) {
572   if (inputs.size() == 0) return true;
573   int length = inputs.at(0).size();
574   for (int i = 1; i < inputs.size(); i++) {
575     if (inputs.at(i).size() != length) return false;
576   }
577   return true;
578 }
579 
580 bool DimsHaveSameSize(const DimsAdapter& lhs, const DimsAdapter& rhs) {
581   return lhs.Volume() == rhs.Volume();
582 }
583 
584 // Returns whether both dimensions are fully specified and the total number of
585 // elements equals.
586 bool AreDimsStaticWithSameSize(const DimsAdapter& lhs, const DimsAdapter& rhs) {
587   if (!lhs.IsStatic() || !rhs.IsStatic()) return false;
588   return DimsHaveSameSize(lhs, rhs);
589 }
590 
591 bool AreDimsStaticWithDifferentSize(const DimsAdapter& lhs,
592                                     const DimsAdapter& rhs) {
593   if (!lhs.IsStatic() || !rhs.IsStatic()) return false;
594   return !DimsHaveSameSize(lhs, rhs);
595 }
596 
597 static std::vector<std::pair<int, int>> CreateSamePadding(
598     const nvinfer1::Dims& stride, const nvinfer1::Dims& kernel,
599     const std::vector<int64_t>& input_dims) {
600   std::vector<std::pair<int, int>> padding(input_dims.size());
601   CHECK_EQ(stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?
602 
603   for (size_t i = 0; i < input_dims.size(); ++i) {
604     // Formula to calculate the padding
605     int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
606             input_dims[i];
607     p = (p > 0) ? p : 0;
608 
609     // Right precedence padding, like in TensorFlow
610     int left = p / 2;
611     int right = p - left;
612 
613     VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
614             << ", params: " << input_dims[i] << ", " << stride.d[i] << ", "
615             << "kernel: " << kernel.d[i];
616     padding[i] = {left, right};
617   }
618   return padding;
619 }
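// Quick numeric check of the padding formula: for input_dim == 7, stride == 2,
// kernel == 3, p = ((7 - 1) / 2) * 2 + 3 - 7 = 2 and the pair is {1, 1}; for
// input_dim == 10 with the same stride and kernel, p = 1 and the pair is
// {0, 1}, matching TensorFlow's right-precedence SAME padding.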
620 
621 string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
622   size_t last_scope_separator = 0;
623   const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
624   for (size_t i = 0; i < min_size; ++i) {
625     if (op_name_a[i] != op_name_b[i]) break;
626     if (op_name_a[i] == '/') last_scope_separator = i + 1;
627   }
628   return op_name_a.substr(0, last_scope_separator);
629 }
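// For example, GetCommonNameScope("model/block1/conv/add", "model/block1/relu")
// returns "model/block1/", and two names with no shared scope return the empty
// string.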
630 
631 // Verifies that shapes of the given inputs match after masking the specified
632 // dimension.
633 Status VerifyShapesMatch(absl::Span<const TRT_TensorOrWeights> inputs,
634                          int masked_dim, absl::string_view node_name) {
635   size_t num_inputs = inputs.size();
636   if (num_inputs <= 1) return Status::OK();
637 
638   const nvinfer1::Dims dims_0 = inputs.at(0).GetTrtDims();
639   for (size_t i = 1; i < num_inputs; ++i) {
640     const nvinfer1::Dims dim_i = inputs.at(i).GetTrtDims();
641     if (dim_i.nbDims != dims_0.nbDims) {
642       return errors::InvalidArgument(
643           "Received inputs with inconsistent rank, at ", node_name);
644     }
645     for (size_t j = 0; j < dims_0.nbDims; ++j) {
646       // Dynamic dimensions will be verified at runtime.
647       if (dim_i.d[j] == -1 || dims_0.d[j] == -1) continue;
648       if (dim_i.d[j] != dims_0.d[j] && j != masked_dim) {
649         return errors::InvalidArgument(
650             "Received inputs with inconsistent shape, at ", node_name);
651       }
652     }
653   }
654   return Status::OK();
655 }
656 
657 // Performs a 5-dimensional reorder of data on the CPU.
658 // This is done once at conversion time and does not affect GPU inference perf.
659 // Example: reorder NDHWC (TensorFlow) -> NCDHW (TensorRT).
660 template <typename T>
661 void Reorder5(const nvinfer1::Dims& shape, const T* idata,
662               const nvinfer1::Dims& istrides, T* odata,
663               const nvinfer1::Dims& ostrides) {
664   for (int k = 0; k < shape.d[0]; ++k) {
665     for (int c = 0; c < shape.d[1]; ++c) {
666       for (int d = 0; d < shape.d[2]; ++d) {
667         for (int r = 0; r < shape.d[3]; ++r) {
668           for (int s = 0; s < shape.d[4]; ++s) {
669             odata[k * ostrides.d[0] + c * ostrides.d[1] + d * ostrides.d[2] +
670                   r * ostrides.d[3] + s * ostrides.d[4]] =
671                 idata[k * istrides.d[0] + c * istrides.d[1] +
672                       d * istrides.d[2] + r * istrides.d[3] +
673                       s * istrides.d[4]];
674           }
675         }
676       }
677     }
678   }
679 }
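// The Reorder helpers are plain strided copies: the caller encodes the source
// and destination layouts entirely in the stride arguments, so the nested loops
// above move each element to its transposed position without any
// layout-specific logic.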
680 
681 // TODO(jie): reorder4 & reorder2 should be merged?
682 // TODO(aaroey): fix the order of parameters.
683 template <typename T>
684 void Reorder4(const nvinfer1::Dims4& shape, const T* idata,
685               const nvinfer1::Dims4& istrides, T* odata,
686               const nvinfer1::Dims4& ostrides) {
687   for (int n = 0; n < shape.d[0]; ++n) {
688     for (int c = 0; c < shape.d[1]; ++c) {
689       for (int h = 0; h < shape.d[2]; ++h) {
690         for (int w = 0; w < shape.d[3]; ++w) {
691           odata[n * ostrides.d[0] + c * ostrides.d[1] + h * ostrides.d[2] +
692                 w * ostrides.d[3]] =
693               idata[n * istrides.d[0] + c * istrides.d[1] + h * istrides.d[2] +
694                     w * istrides.d[3]];
695         }
696       }
697     }
698   }
699 }
700 
701 template <typename T>
702 void Reorder2(const nvinfer1::DimsHW& shape, const T* idata,
703               const nvinfer1::DimsHW& istrides, T* odata,
704               const nvinfer1::DimsHW& ostrides) {
705   for (int h = 0; h < shape.h(); ++h) {
706     for (int w = 0; w < shape.w(); ++w) {
707       odata[h * ostrides.h() + w * ostrides.w()] =
708           idata[h * istrides.h() + w * istrides.w()];
709     }
710   }
711 }
712 
713 // TODO(jie): fallback to tensorflow!!
714 void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
715                    TRT_ShapedWeights* oweights) {
716   const int c = iweights.Shape().dim(0);
717   const int k = iweights.Shape().dim(1);
718   oweights->Shape().dim(0) = k;
719   oweights->Shape().dim(1) = c;
720   const nvinfer1::DimsHW istrides = {1, k};
721   const nvinfer1::DimsHW ostrides = {c, 1};
722   switch (iweights.TrtDType()) {
723     case nvinfer1::DataType::kFLOAT: {
724       Reorder2({k, c}, iweights.GetPointer<float>(), istrides,
725                oweights->GetPointer<float>(), ostrides);
726       break;
727     }
728     case nvinfer1::DataType::kHALF: {
729       Reorder2({k, c}, iweights.GetPointer<Eigen::half>(), istrides,
730                oweights->GetPointer<Eigen::half>(), ostrides);
731       break;
732     }
733     default:
734       LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
735                  << DebugString(iweights.TrtDType());
736   }
737 }
738 
739 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
740                        TRT_ShapedWeights* oweights, const int num_groups) {
741   CHECK(iweights.TrtDType() == oweights->TrtDType());
742   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
743   // K indexes over output channels, C over input channels, and R and S over the
744   // height and width of the convolution
745   const int r = iweights.Shape().dim(0);
746   const int s = iweights.Shape().dim(1);
747   // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
748   const int c = iweights.Shape().dim(2) / num_groups;
749   const int k = iweights.Shape().dim(3) * num_groups;
750   VLOG(2) << "num_groups: " << num_groups << ", c: " << iweights.Shape().dim(2)
751           << " then " << c << ", k: " << iweights.Shape().dim(3) << " then " << k
752           << ", r: " << iweights.Shape().dim(0) << " then " << r << ", s: "
753           << iweights.Shape().dim(1) << " then " << s;
754   oweights->Shape().dim(0) = k / num_groups;
755   oweights->Shape().dim(1) = c * num_groups;
756   oweights->Shape().dim(2) = r;
757   oweights->Shape().dim(3) = s;
758   const nvinfer1::Dims4 istrides = {1, k, s * k * c, c * k};
759   const nvinfer1::Dims4 ostrides = {c * r * s, r * s, s, 1};
760   switch (iweights.TrtDType()) {
761     case nvinfer1::DataType::kFLOAT: {
762       Reorder4({k, c, r, s}, iweights.GetPointer<float>(), istrides,
763                oweights->GetPointer<float>(), ostrides);
764       break;
765     }
766     case nvinfer1::DataType::kHALF: {
767       Reorder4({k, c, r, s}, iweights.GetPointer<Eigen::half>(), istrides,
768                oweights->GetPointer<Eigen::half>(), ostrides);
769       break;
770     }
771 
772     default:
773       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
774                  << DebugString(iweights.TrtDType());
775   }
776 }
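// Sanity check of the strides above with a hypothetical 1x1 convolution
// (r == s == 1, c == 2, k == 3, num_groups == 1): the RSCK element at
// (0, 0, ci, ki) sits at input offset ci * k + ki and is written to output
// offset ki * (c * r * s) + ci * (r * s) = ki * 2 + ci, i.e. the weights come
// out in the KCRS order that TensorRT expects.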
777 
778 // Initialize a Dims object with arbitrary dimension
779 nvinfer1::Dims InitDimsN(std::initializer_list<int> list) {
780   nvinfer1::Dims dim;
781   dim.nbDims = list.size();
782   std::copy(list.begin(), list.end(), dim.d);
783   return dim;
784 }
785 
786 // Reorder 3D convolution weights from TF to TRT
787 void ReorderDRSCKToKCDRS(const TRT_ShapedWeights& iweights,
788                          TRT_ShapedWeights* oweights, const int num_groups) {
789   DCHECK(iweights.TrtDType() == oweights->TrtDType());
790   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
791   // K indexes over output channels, C over input channels, and R, S, D over the
792   // height, width, depth
793   const int d = iweights.Shape().dim(0);
794   const int r = iweights.Shape().dim(1);
795   const int s = iweights.Shape().dim(2);
796   // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
797   const int c = iweights.Shape().dim(3) / num_groups;
798   const int k = iweights.Shape().dim(4) * num_groups;
799 
800   VLOG(2) << "num_groups: " << num_groups << ", c: " << iweights.Shape().dim(3)
801           << " becomes " << c << ", k: " << iweights.Shape().dim(4)
802           << " becomes " << k << ", d: " << d << ", r: " << r << ", s: " << s;
803 
804   oweights->Shape().dim(0) = iweights.Shape().dim(4);  // k / num_groups;
805   oweights->Shape().dim(1) = iweights.Shape().dim(3);  // c * num_groups;
806   oweights->Shape().dim(2) = d;
807   oweights->Shape().dim(3) = r;
808   oweights->Shape().dim(4) = s;
809 
810   nvinfer1::Dims shape =
811       InitDimsN({k, c, d, r, s});  // KCDRS shape (same as output)
812 
813   nvinfer1::Dims ostrides =
814       InitDimsN({c * d * r * s, d * r * s, r * s, s,
815                  1});  // Output = KCDRS = k*CDRS + c*DRS + d*RS + r*S + s
816 
817   nvinfer1::Dims istrides =
818       InitDimsN({1, k, r * s * c * k, s * c * k,
819                  c * k});  // Input = DRSCK = k*1 + c*K + d*RSCK + r*SCK + s*CK
820 
821   switch (iweights.TrtDType()) {
822     case nvinfer1::DataType::kFLOAT: {
823       Reorder5(shape, iweights.GetPointer<float>(), istrides,
824                oweights->GetPointer<float>(), ostrides);
825       break;
826     }
827     case nvinfer1::DataType::kHALF: {
828       Reorder5(shape, iweights.GetPointer<Eigen::half>(), istrides,
829                oweights->GetPointer<Eigen::half>(), ostrides);
830       break;
831     }
832     default:
833       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
834                  << DebugString(iweights.TrtDType());
835   }
836 }
837 
838 OpConverterParams::OpConverterParams(
839     const NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs,
840     std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store,
841     TrtPrecisionMode precision_mode, bool use_calibration,
842     bool use_implicit_batch, bool use_explicit_precision)
843     : node_def(node_def),
844       inputs(inputs),
845       outputs(outputs),
846       validation_only(true),
847       weight_store(weight_store),
848       precision_mode(precision_mode),
849       use_calibration(use_calibration),
850       use_implicit_batch(use_implicit_batch),
851       use_explicit_precision(use_explicit_precision) {}
852 
853 OpConverterParams::OpConverterParams(
854     Converter* converter, const NodeDef& node_def,
855     const std::vector<TRT_TensorOrWeights>& inputs,
856     std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store)
857     : converter(converter),
858       node_def(node_def),
859       inputs(inputs),
860       outputs(outputs),
861       validation_only(false),
862       weight_store(weight_store),
863       precision_mode(converter->precision_mode()),
864       use_calibration(converter->use_calibration()),
865       use_implicit_batch(converter->use_implicit_batch()),
866       use_explicit_precision(converter->UseExplicitPrecision()) {}
867 
868 TrtNodeValidator::TrtNodeValidator(
869     const grappler::GraphProperties& graph_properties,
870     TrtPrecisionMode precision_mode, bool use_calibration,
871     bool use_implicit_batch, bool use_explicit_precision)
872     : graph_properties_(graph_properties),
873       precision_mode_(precision_mode),
874       use_calibration_(use_calibration),
875       use_implicit_batch_(use_implicit_batch),
876       use_explicit_precision_(use_explicit_precision) {}
877 
878 StatusOr<OpConverter> TrtNodeValidator::GetValidator(const std::string& op) {
879   return GetOpConverterRegistry()->LookUp(op);
880 }
881 
882 Status TrtNodeValidator::ConvertToTensorOrWeights(
883     const NodeDef& node_def, int output_port,
884     TRT_TensorOrWeights* tensor_or_weights) {
885   // Treat handles separately.
886   if (node_def.op() == "VarHandleOp" || node_def.op() == "Placeholder") {
887     AttrSlice attrs(node_def);
888     DataType dtype;
889     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "dtype", &dtype));
890     if (dtype == DataType::DT_RESOURCE) {
891       // The converter doesn't use the input resource at the validation stage
892       // (it gets the dtype and shape from attributes). A fake resource can be
893       // used.
894       ResourceHandle fake_resource;
895       *tensor_or_weights = TRT_TensorOrWeights(fake_resource);
896       return Status::OK();
897     }
898   }
899 
900   if (node_def.op() == "Const" || node_def.op() == "VariableV2") {
901     // The output of the conversion will be used as input to other nodes to
902     // determine whether TRT supports those nodes. If it cannot convert the
903     // Const, it's very likely we cannot treat it as a tensor and make it an
904     // input to the TRT network, since TRT removes the first dimension and
905     // treats it as batch size. Also, it's not likely that the converter can
906     // support the op, and performance may suffer even if it can, so we just
907     // simply return error if the conversion fails.
908     if (output_port != 0) {
909       return errors::InvalidArgument(node_def.op(),
910                                      " node should only have one output.");
911     }
912     std::vector<TRT_TensorOrWeights> inputs;
913     return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
914   }
915   if (node_def.op() == "ReadVariableOp") {
916     // Similar treatment to Const and VariableV2, but we provide a fake
917     // resource input to the converter.
918     const std::vector<TRT_TensorOrWeights> inputs{
919         TRT_TensorOrWeights(ResourceHandle())};
920 
921     // Convert the variable to weights.
922     return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
923   }
924   if (!graph_properties_.HasOutputProperties(node_def.name())) {
925     return errors::InvalidArgument("Shape and data type are unknown");
926   }
927 
928   // Validate and convert shape and dtype.
929   const auto& output_params =
930       graph_properties_.GetOutputProperties(node_def.name());
931   const auto& tensor_properties = output_params.at(output_port);
932   const DataType dtype = tensor_properties.dtype();
933   const PartialTensorShape shape = tensor_properties.shape();
934   nvinfer1::DataType trt_dtype;
935   nvinfer1::Dims trt_dims;
936   int batch_size = -1;
937   TF_RETURN_IF_ERROR(ValidateTensorProperties(
938       node_def.op(), dtype, shape, use_implicit_batch_,
939       /*validation_only=*/true, &trt_dtype, &trt_dims, &batch_size));
940 
941   // Adds a fake ITensor. This is fine since op converter operates in
942   // validation-only mode and it won't (and shouldn't) use the tensor to do
943   // any TRT network operations.
944   *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size);
945   return Status::OK();
946 }
947 
948 Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) {
949   const string& op = node->def().op();
950   // In INT8 mode, we will always apply the quantization ranges provided by
951   // these ops to the relevant tensors. This happens regardless of the value of
952   // use_calibration.
953   bool is_supported_op = false;
954   if (absl::c_find(kQuantizationOpNames, op) != kQuantizationOpNames.end()) {
955     is_supported_op = (precision_mode_ == TrtPrecisionMode::INT8);
956   } else {
957     is_supported_op = GetValidator(op).ok();
958   }
959 
960   if (!is_supported_op) {
961     return errors::Unimplemented("Op type ", op, " is not supported.");
962   }
963 
964   // Convert input NodeDef and corresponding output ports to
965   // TRT_TensorOrWeights.
966   std::vector<TRT_TensorOrWeights> inputs;
967   std::vector<const Edge*> input_edges;
968   TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
969   for (const Edge* edge : input_edges) {
970     // Go up the chain of Identity nodes.
971     Node* src_node = edge->src();
972     while (src_node->def().op() == "Identity") {
973       std::vector<const Edge*> input_edges_temp;
974       TF_RETURN_IF_ERROR(src_node->input_edges(&input_edges_temp));
975       src_node = input_edges_temp[0]->src();
976     }
977     const NodeDef& src_def = src_node->def();
978 
979     TRT_TensorOrWeights tensor_or_weights;
980     Status status = ConvertToTensorOrWeights(src_def, edge->src_output(),
981                                              &tensor_or_weights);
982     if (!status.ok()) {
983       VLOG(2) << "Failed to convert input `" << src_def.name() << "` to a "
984               << "TRT_TensorOrWeights: " << status.error_message();
985 
986       return errors::Internal(
987           "Failed to convert at least one input to a TRT_TensorOrWeights: ",
988           status.error_message());
989     }
990     inputs.push_back(tensor_or_weights);
991   }
992 
993   auto validator = GetValidator(op);
994   TF_RETURN_IF_ERROR(validator.status());
995   OpConverterParams params(node->def(), inputs, /*arg_outputs=*/nullptr,
996                            &weight_store_, precision_mode_, use_calibration_,
997                            use_implicit_batch_, use_explicit_precision_);
998   return (*validator)(&params);
999 }
1000 
1001 Status TrtNodeValidator::ConvertConstToWeights(
1002     const NodeDef& const_node_def,
1003     const std::vector<TRT_TensorOrWeights>& inputs,
1004     TRT_TensorOrWeights* output) {
1005   std::vector<TRT_TensorOrWeights> outputs;
1006   OpConverterParams params(const_node_def, inputs, &outputs, &weight_store_,
1007                            precision_mode_, use_calibration_,
1008                            use_implicit_batch_, use_explicit_precision_);
1009   auto const_val = GetValidator(const_node_def.op());
1010   TF_RETURN_IF_ERROR(const_val.status());
1011   Status status = (*const_val)(&params);
1012   if (status.ok() && (output != nullptr)) {
1013     *output = outputs[0];
1014   }
1015   return status;
1016 }
1017 
1018 // static
1019 StatusOr<std::unique_ptr<Converter>> Converter::Create(
1020     TrtPrecisionMode precision_mode, bool use_calibration,
1021     nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
1022     absl::string_view engine_name, bool use_explicit_precision,
1023     OpKernelContext* ctx) {
1024   std::unique_ptr<Converter> converter = absl::WrapUnique(new Converter(
1025       precision_mode, use_calibration, trt_logger, use_implicit_batch,
1026       engine_name, use_explicit_precision, ctx));
1027   TF_RETURN_IF_ERROR(converter->Init(trt_logger));
1028   return converter;
1029 }
1030 
1031 Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration,
1032                      nvinfer1::ILogger* trt_logger,
1033                      const bool use_implicit_batch,
1034                      absl::string_view engine_name, bool use_explicit_precision,
1035                      OpKernelContext* ctx)
1036     : ctx_(ctx),
1037       precision_mode_(precision_mode),
1038       use_calibration_(use_calibration),
1039       use_implicit_batch_(use_implicit_batch),
1040       engine_name_(engine_name),
1041       use_explicit_precision_(use_explicit_precision) {
1042   MaybeInitializeTrtPlugins(trt_logger);
1043 }
1044 
1045 Status Converter::Init(nvinfer1::ILogger* trt_logger) {
1046   VLOG(1) << "Creating TensorRT builder";
1047   trt_builder_.reset(nvinfer1::createInferBuilder(*trt_logger));
1048 
1049   VLOG(1) << "Creating TensorRT network";
1050   uint32_t flags =
1051       use_implicit_batch_
1052           ? 0U
1053           : (1U << static_cast<int>(
1054                  nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
1055   if (use_explicit_precision_) {
1056     flags |=
1057         (1U << static_cast<int>(
1058              nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_PRECISION));
1059   }
1060   trt_network_.reset(trt_builder_->createNetworkV2(flags));
1061   if (!trt_network_) {
1062     return errors::Internal("Failed to create TensorRT network object");
1063   }
1064   return Status::OK();
1065 }
1066 
1067 Status Converter::ConvertNode(const NodeDef& node_def) {
1068   std::vector<TRT_TensorOrWeights> inputs;
1069   std::vector<TRT_TensorOrWeights> outputs;
1070   TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
1071 
1072   OpConverterParams params(this, node_def, inputs, &outputs, &weight_store_);
1073   const string& op = node_def.op();
1074   auto op_converter = GetOpConverterRegistry()->LookUp(op);
1075   TF_RETURN_IF_ERROR(op_converter.status());
1076   TF_RETURN_IF_ERROR((*op_converter)(&params));
1077 
1078   for (size_t i = 0; i < outputs.size(); ++i) {
1079     TRT_TensorOrWeights& output = outputs[i];
1080     string output_name = node_def.name();
1081     if (i != 0) {
1082       StrAppend(&output_name, ":", i);
1083     }
1084     // We need to check the name before setting it. If the input is one of the
1085     // engine inputs, setting the name here would overwrite the engine input
1086     // bindings, which would cause a runtime error.
1087     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1088     // in ConvertIdentity.
1089     if (output.is_tensor()) {
1090       const char* tensor_name = output.tensor()->getName();
1091       if (!IsEngineInput(tensor_name)) {
1092         // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename
1093         // them to match their corresponding TensorFlow name.
1094         // Note: ITensors that we create internally within TF-TRT which are
1095         // not inputs or outputs of a node will not be renamed. This is a
1096         // potential cause of confusion if an error message or warning
1097         // mentions the unnamed tensor.
1098         output.tensor()->setName(output_name.c_str());
1099       }
1100     }
1101     VLOG(2) << "Adding out tensor " << output_name << ": "
1102             << output.DebugString();
1103     Status status = AddTensorOrWeights(output_name, output);
1104     if (!status.ok()) {
1105       return errors::Create(
1106           status.code(),
1107           StrCat("Failed to add output for node: ", node_def.name(), ": ",
1108                  status.error_message()),
1109           errors::GetPayloads(status));
1110     }
1111   }
1112   return Status::OK();
1113 }
1114 
1115 Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype,
1116                                  const nvinfer1::Dims& dims, int batch_size) {
1117   // We verify the batch size only for the input nodes, and rely on individual
1118   // op converters to ensure the batch size of the outputs is not changed.
1119   // TODO(laigd): we need to test these properties.
1120   Status status;
1121   if (use_implicit_batch_) {
1122     status = MaybeUpdateBatchSize(batch_size);
1123     if (!status.ok()) {
1124       return errors::CreateWithUpdatedMessage(
1125           status, StrCat("Batch size doesn't match for tensor ", name, ": ",
1126                          status.error_message()));
1127     }
1128   }
1129   ITensorProxyPtr tensor = network()->addInput(name.c_str(), dtype, dims);
1130   if (*tensor == nullptr) {
1131     return errors::InvalidArgument("Failed to create Input layer tensor ", name,
1132                                    " rank=", dims.nbDims);
1133   }
1134   status = AddTensorOrWeights(name, TRT_TensorOrWeights(tensor));
1135   if (!status.ok()) {
1136     return errors::CreateWithUpdatedMessage(
1137         status, StrCat("Failed to add input tensor ", name, ": ",
1138                        status.error_message()));
1139   }
1140   return Status::OK();
1141 }
1142 
1143 Status Converter::AddInputResource(const string& name,
1144                                    const ResourceHandle& resource) {
1145   Status status = AddTensorOrWeights(name, TRT_TensorOrWeights(resource));
1146   if (!status.ok()) {
1147     return errors::CreateWithUpdatedMessage(
1148         status, StrCat("Failed to add input resource ", name, ": ",
1149                        status.error_message()));
1150   }
1151   return Status::OK();
1152 }
1153 
1154 Status Converter::RenameAndMarkOutputTensors(
1155     const std::vector<Converter::EngineOutputInfo>& output_tensors) {
1156   int output_index = 0;
1157   for (const auto& output : output_tensors) {
1158     TRT_TensorOrWeights tensor_or_weights;
1159     TF_RETURN_IF_ERROR(
1160         GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights));
1161     if (!tensor_or_weights.is_tensor()) {
1162       return errors::InvalidArgument("Output ", output.source_tensor_name,
1163                                      " is weights not tensor");
1164     }
1165     ITensorProxyPtr tensor = tensor_or_weights.tensor();
1166     if (*tensor == nullptr) {
1167       return errors::NotFound("Output tensor not found: ",
1168                               output.source_tensor_name);
1169     }
1170     // Check if this tensor has already been marked as an input or output.
1171     //
1172     // ConvertIdentity can cause the same tensor to be repeated in
1173     // output_tensors, which can cause us to overwrite the name of the output
1174     // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then
1175     // we won't be able to locate OutputPH_0 during runtime. To fix this,
1176     // duplicate the tensor using no-op shuffle.
1177     //
1178     // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
1179     // in ConvertIdentity.
1180     if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) {
1181       // Using shuffle layer for identity by not setting reshape or transpose.
1182       nvinfer1::IShuffleLayer* layer =
1183           network()->addShuffle(*tensor->trt_tensor());
1184       TFTRT_RETURN_ERROR_IF_NULLPTR(
1185           layer, StrCat("Output Copy for ", tensor->getName()));
1186       SetLayerName(layer, tensor->getName(), "shuffle", output_index);
1187       tensor = layer->getOutput(0);
1188     }
1189     tensor->setName(output.dest_node_name.c_str());
1190     network()->markOutput(*tensor->trt_tensor());
1191     // Set type after marking as output. TRT only supports setType for engine
1192     // outputs and inputs (type is inferred otherwise).
1193     tensor->setType(output.trt_dtype);
1194     output_index++;
1195     VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name
1196             << " with data type " << DebugString(output.trt_dtype)
1197             << ", which feeds TF node " << output.dest_node_name;
1198   }
1199   if (VLOG_IS_ON(2)) {
1200     VLOG(2) << "Created TensorRT network with the following layers:";
1201     for (int i = 0; i < network()->getNbLayers(); i++) {
1202       auto layer = network()->getLayer(i);
1203       VLOG(2) << "    " << layer->getName() << " ("
1204               << "type: " << static_cast<int>(layer->getType())
1205               << ", precision: " << static_cast<int>(layer->getPrecision())
1206               << ")";
1207     }
1208   }
1209   return Status::OK();
1210 }
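// Editorial note: an illustrative sketch (not in the original source) of how
// engine outputs could be marked once all nodes are converted. The tensor
// name, destination node name, and the exact EngineOutputInfo field order are
// assumptions here.
//
//   std::vector<Converter::EngineOutputInfo> outputs;
//   outputs.push_back({/*source_tensor_name=*/"logits",
//                      /*dest_node_name=*/"TensorRTOutputPH_0",
//                      /*trt_dtype=*/nvinfer1::DataType::kFLOAT});
//   TF_RETURN_IF_ERROR(converter->RenameAndMarkOutputTensors(outputs));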
1211 
1212 // Returns the value of TF_TRT_ABORT_CUDA_ENGINE_BUILD environment variable.
1213 // This variable can be used to abort CUDA engine construction, therefore it
1214 // provides a way to test and debug the native segment fallback of TF-TRT.
1215 bool AbortCudaEngineBuild() {
1216   bool value;
1217   Status status = ReadBoolFromEnvVar("TF_TRT_ABORT_CUDA_ENGINE_BUILD",
1218                                      /*default_value=*/false, &value);
1219   if (!status.ok()) {
1220     LOG(ERROR) << status;
1221   }
1222   return value;
1223 }
1224 
1225 Status Converter::BuildCudaEngine(
1226     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, int max_batch_size,
1227     size_t max_workspace_size_bytes, nvinfer1::IGpuAllocator* allocator,
1228     TRTInt8Calibrator* calibrator, TrtShapeOptimizationProfile* profiles) {
1229   tensorflow::profiler::AnnotatedTraceMe activity(
1230       [&]() {
1231         return tensorflow::profiler::TraceMeOpOverride("TRTEngineOp",
1232                                                        "BuildEngine");
1233       },
1234       tensorflow::profiler::TraceMeLevel::kInfo);
1235 
1236   if (AbortCudaEngineBuild()) {
1237     return errors::Aborted(
1238         "Engine creation aborted by TF_TRT_ABORT_CUDA_ENGINE_BUILD variable");
1239   }
1240 
1241   VLOG(1) << "Configuring TensorRT builder";
1242   trt_builder_->setMaxBatchSize(max_batch_size);
1243   trt_builder_->setGpuAllocator(allocator);
1244 
1245   // Create a network configuration and use it to build a TRT engine.
1246   TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config(
1247       trt_builder_->createBuilderConfig());
1248   builder_config->setMaxWorkspaceSize(max_workspace_size_bytes);
1249 
1250   // Create the algorithm selector. For TensorRT 7.x, the algorithm selector
1251   // cannot be used when building with INT8 calibration.
1252   std::unique_ptr<nvinfer1::IAlgorithmSelector> trt_algorithm_selector{nullptr};
1253   if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
1254     if (!use_calibration_ || precision_mode_ != TrtPrecisionMode::INT8) {
1255       trt_algorithm_selector = MaybeCreateAlgorithmSelector();
1256     }
1257   } else {
1258     trt_algorithm_selector = MaybeCreateAlgorithmSelector();
1259   }
1260 
1261   if (trt_algorithm_selector != nullptr) {
1262     builder_config->setAlgorithmSelector(trt_algorithm_selector.get());
1263   }
1264 
1265 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
1266   builder_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
1267   VLOG(1) << "Setting sparsity for TensorRT8!";
1268 #endif
1269 
1270   if (tensorflow::tensor_float_32_execution_enabled()) {
1271     builder_config->setFlag(nvinfer1::BuilderFlag::kTF32);
1272   } else {
1273     builder_config->clearFlag(nvinfer1::BuilderFlag::kTF32);
1274   }
1275 
1276   if (precision_mode_ == TrtPrecisionMode::FP16) {
1277     builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
1278   } else if (precision_mode_ == TrtPrecisionMode::INT8) {
1279     // FP16 is not available in Explicit Precision mode with TensorRT 7.
1280     if (IS_TRT_VERSION_GE(8, 0, 0, 0) || !use_explicit_precision_) {
1281       builder_config->setFlag(nvinfer1::BuilderFlag::kFP16);
1282     } else {
1283       LOG_WARNING_WITH_PREFIX << "With explicit precision mode, FP16 is not "
1284                                  "allowed before TensorRT 8. TRT will consider "
1285                                  "INT8 and FP32 tactics.";
1286     }
1287     builder_config->setFlag(nvinfer1::BuilderFlag::kINT8);
1288   }
1289   if (!use_implicit_batch_ && profiles) {
1290     TF_RETURN_IF_ERROR(profiles->ConfigureBuilder(
1291         trt_builder_.get(), builder_config.get(), network()));
1292   }
1293   if (precision_mode_ == TrtPrecisionMode::INT8) {
1294     builder_config->setInt8Calibrator(use_calibration_ ? calibrator : nullptr);
1295   }
1296 
1297   std::unique_ptr<TimingCacheRegistry::TimingCache> timing_cache = nullptr;
1298   // We only use a timing cache if the algorithm selector is not used. If we
1299   // are using TRT version >= 8.0, then we can try to deserialize an existing
1300   // cache.
1301   if (trt_algorithm_selector == nullptr) {
1302 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
1303     TimingCacheRegistry* registry = GetTimingCacheRegistry();
1304 
1305     auto cache = registry->LookUp("default_cache", builder_config.get());
1306     if (!cache.ok()) {
1307       LOG(WARNING) << "failed to create a timing cache: "
1308                    << cache.status().error_message();
1309     } else {
1310       timing_cache = std::move(*cache);
1311       builder_config->setTimingCache(*timing_cache, /*ignoreMismatch*/ false);
1312     }
1313 #endif  // IS_TRT_VERSION_GE(8, 0, 0, 0)
1314   } else {
1315     // Disabling the timing cache is recommended when using the algorithm
1316     // selector.
1317     builder_config->setFlag(nvinfer1::BuilderFlag::kDISABLE_TIMING_CACHE);
1318   }
1319 
1320   string precision_mode_str;
1321   TF_RETURN_IF_ERROR(
1322       TrtPrecisionModeToName(precision_mode_, &precision_mode_str));
1323   string trt_network_name = StrCat(
1324       "TF:", TF_VERSION_STRING, ", ",
1325       "TRT:", absl::StrJoin(GetLoadedTensorRTVersion(), "."), "-",
1326       "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_,
1327       ", ", "Max-Batch-Size:", max_batch_size, ", ",
1328       "Max-Workspace-Size:", max_workspace_size_bytes);
1329   VLOG(1) << "Setting TensorRT network name to " << trt_network_name;
1330   network()->setName(trt_network_name.c_str());
1331 
1332   VLOG(1) << "Building TensorRT engine";
1333   if (VLOG_IS_ON(2)) {
1334     VLOG(2) << "Network inputs";
1335     int n_inputs = network()->getNbInputs();
1336     for (int i = 0; i < n_inputs; i++) {
1337       const ITensorProxyPtr input = network()->getInput(i);
1338       if (*input) {
1339         VLOG(2) << "  " << i << " " << input->getName();
1340       } else {
1341         VLOG(2) << "Could not find input " << i;
1342       }
1343     }
1344   }
1345   engine->reset(
1346       trt_builder_->buildEngineWithConfig(*network(), *builder_config));
1347   if (engine->get() == nullptr) {
1348     return errors::Internal("Failed to build TensorRT engine");
1349   }
1350   if (VLOG_IS_ON(2)) {
1351     VLOG(2) << "TRT engine created";
1352     int nbBindings = (*engine)->getNbBindings();
1353     VLOG(2) << "Number of engine bindings: " << nbBindings;
1354     for (int i = 0; i < nbBindings; i++) {
1355       auto get_location_string = [&engine](int i) {
1356         if ((*engine)->getLocation(i) == nvinfer1::TensorLocation::kDEVICE)
1357           return " on device";
1358         else
1359           return " on host";
1360       };
1361       VLOG(2) << "Binding " << i << " name: " << (*engine)->getBindingName(i)
1362               << get_location_string(i);
1363     }
1364   }
1365 
1366   // Write back the new timing cache results to the registry.
1367   if (timing_cache) {
1368     GetTimingCacheRegistry()->Upsert("default_cache", timing_cache.get());
1369   }
1370 
1371   return Status::OK();
1372 }
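// Editorial note: a hedged sketch (not part of the original source) of a
// typical call once the network has been populated; `converter`, `allocator`,
// and `profiles` are assumed to have been set up by the caller.
//
//   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
//   TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
//       &engine, /*max_batch_size=*/8,
//       /*max_workspace_size_bytes=*/static_cast<size_t>(1) << 30, allocator,
//       /*calibrator=*/nullptr, profiles));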
1373 
1374 Status Converter::MaybeUpdateBatchSize(int batch_size) {
1375   // OK iff either is unknown or they are equal to each other.
1376   if (this->batch_size_ < 0 || batch_size < 0 ||
1377       this->batch_size_ == batch_size) {
1378     if (this->batch_size_ < 0 && batch_size >= 0) {
1379       this->batch_size_ = batch_size;
1380     }
1381     return Status::OK();
1382   }
1383   return errors::InvalidArgument(
1384       "Provided batch size does not match converter batch size: ", batch_size,
1385       " vs ", batch_size_);
1386 }
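// Editorial note: the reconciliation rule above as a worked example (not in
// the original source). Starting from an unknown converter batch size (-1):
//   MaybeUpdateBatchSize(-1)  -> OK, batch size stays unknown.
//   MaybeUpdateBatchSize(8)   -> OK, batch size becomes 8.
//   MaybeUpdateBatchSize(8)   -> OK, matches the recorded value.
//   MaybeUpdateBatchSize(16)  -> InvalidArgument ("16 vs 8").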
1387 
1388 Status Converter::AddTensorOrWeights(const string& name,
1389                                      TRT_TensorOrWeights input) {
1390   // Set the batch size of the tensor, using batch size collected from the
1391   // input tensors to the TRT subgraph at the beginning of the conversion.
1392   // We rely on the individual op converter to understand the semantics of the
1393   // TF node, and make sure it doesn't change the batch size nor introduce
1394   // intra-element dependency inside the batch.
1395   if (use_implicit_batch_ && input.is_tensor()) {
1396     input.set_batch_size(batch_size_);
1397   }
1398   if (trt_tensors_.insert({name, std::move(input)}).second) return Status::OK();
1399   return errors::AlreadyExists("tensor/weights ", name, " already exists.");
1400 }
1401 
1402 Status Converter::GetTensorOrWeights(const string& name,
1403                                      TRT_TensorOrWeights* output) {
1404   if (!trt_tensors_.count(name)) {
1405     return errors::NotFound("Tensor or weights with name ", name,
1406                             " could not be found.");
1407   }
1408   *output = trt_tensors_.at(name);
1409   return Status::OK();
1410 }
1411 
1412 Status Converter::TransposeTensor(ITensorProxyPtr input_tensor,
1413                                   const std::vector<int>& order_with_batch_dim,
1414                                   ITensorProxyPtr* output_tensor,
1415                                   const NodeDef& node_def,
1416                                   absl::string_view sub_op_name) {
1417   const auto dims = input_tensor->getDimensions();
1418   const int order_size = use_implicit_batch_ ? order_with_batch_dim.size() - 1
1419                                              : order_with_batch_dim.size();
1420   if (order_size != size_t(dims.nbDims)) {
1421     return errors::InvalidArgument(
1422         "Rank of perm for transpose does not match with that of the input.");
1423   }
1424   if (use_implicit_batch_ && order_with_batch_dim[0] != 0) {
1425     return errors::Unimplemented(
1426         "Transpose at batch dimension is not supported.");
1427   }
1428 
1429   nvinfer1::IShuffleLayer* layer =
1430       this->network()->addShuffle(*input_tensor->trt_tensor());
1431   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose");
1432   SetLayerName(layer, node_def, sub_op_name);
1433 
1434   nvinfer1::Permutation permutation;
1435   if (use_implicit_batch_) {
1436     for (int32_t i = 0; i < dims.nbDims; ++i) {
1437       permutation.order[i] = order_with_batch_dim[i + 1] - 1;
1438     }
1439   } else {
1440     std::copy(order_with_batch_dim.begin(), order_with_batch_dim.end(),
1441               permutation.order);
1442   }
1443   VLOG(1) << "TransposeTensor permutation: "
1444           << DebugString(permutation, dims.nbDims);
1445   layer->setFirstTranspose(permutation);
1446 
1447   nvinfer1::Dims reshape_dims;
1448   reshape_dims.nbDims = dims.nbDims;
1449   for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
1450     reshape_dims.d[i] = 0;
1451   }
1452   layer->setReshapeDimensions(reshape_dims);
1453 
1454   *output_tensor = layer->getOutput(0);
1455   return Status::OK();
1456 }
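// Editorial note: an illustrative call (not in the original source) that
// permutes an NHWC tensor to NCHW, as done by the convolution converters
// below. `order_with_batch_dim` always includes the batch dimension, even in
// implicit batch mode.
//
//   ITensorProxyPtr transposed = nullptr;
//   TF_RETURN_IF_ERROR(converter->TransposeTensor(
//       input_tensor, {0, 3, 1, 2}, &transposed, node_def, "to_NCHW"));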
1457 
1458 Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
1459                                  float* out_min, float* out_max) const {
1460   switch (weights.TrtDType()) {
1461     case nvinfer1::DataType::kFLOAT: {
1462       auto inp = weights.GetPointer<float>();
1463       auto result = std::minmax_element(inp, inp + weights.count());
1464       *out_min = *result.first;
1465       *out_max = *result.second;
1466       break;
1467     }
1468     case nvinfer1::DataType::kHALF: {
1469       auto inp = weights.GetPointer<Eigen::half>();
1470       auto result = std::minmax_element(inp, inp + weights.count());
1471       *out_min = static_cast<float>(*result.first);
1472       *out_max = static_cast<float>(*result.second);
1473       break;
1474     }
1475     case nvinfer1::DataType::kINT32: {
1476       auto inp = weights.GetPointer<int>();
1477       auto result = std::minmax_element(inp, inp + weights.count());
1478       *out_min = static_cast<float>(*result.first);
1479       *out_max = static_cast<float>(*result.second);
1480       break;
1481     }
1482     default:
1483       return errors::Unimplemented(
1484           "Data type not supported for GetWeightRange: ",
1485           DebugString(weights.TrtDType()));
1486   }
1487   return Status::OK();
1488 }
1489 
1490 // Constructs <tf_related_part> for the ILayer name as
1491 // <tf_node_def_name>_<sub_op_name>_<sub_op_instance> and calls SetLayerNameHelper
1492 // to set the name for the ILayer.
1493 //
1494 // If the operation represented by the ILayer is generated by the converter to
1495 // support the conversion of node_def, callers need to specify a non-empty
1496 // sub_op_name to be appended to the name of node_def to avoid layer name
1497 // conflicts. If the operation is generated multiple times, callers also need
1498 // to specify sub_op_instance to be appended to the name of the layers to avoid
1499 // layer name conflicts.
1500 void Converter::SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def,
1501                              absl::string_view sub_op_name,
1502                              std::optional<int> sub_op_instance,
1503                              std::optional<std::string> origin_node_name) {
1504   std::string sub_op_suffix = GetLayerNameSuffix(sub_op_name, sub_op_instance);
1505   if (sub_op_suffix.empty()) {
1506     SetLayerNameHelper(layer, engine_name_, node_def.name());
1507   } else if (origin_node_name.has_value()) {
1508     auto layer_name = absl::StrCat(node_def.name(), "-",
1509                                    absl::string_view(origin_node_name.value()),
1510                                    "-", sub_op_suffix);
1511     SetLayerNameHelper(layer, engine_name_, layer_name);
1512   } else {
1513     SetLayerNameHelper(layer, engine_name_,
1514                        absl::StrCat(node_def.name(), "-", sub_op_suffix));
1515   }
1516 }
1517 
1518 // Constructs <tf_related_part> for the ILayer name as
1519 // <main_op_name>_<sub_op_name>_<sub_op_instance> and calls SetLayerNameHelper to
1520 // set the name for the ILayer.
1521 void Converter::SetLayerName(nvinfer1::ILayer* layer,
1522                              absl::string_view main_op_name,
1523                              absl::string_view sub_op_name,
1524                              std::optional<int> sub_op_instance) {
1525   std::string layer_name_suffix =
1526       GetLayerNameSuffix(sub_op_name, sub_op_instance);
1527   SetLayerNameHelper(layer, engine_name_,
1528                      absl::StrCat(main_op_name, "-", layer_name_suffix));
1529 }
1530 
1531 // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims'
1532 // (which doesn't contain the batch dimension).
1533 //
1534 // If validation_only is true, it doesn't do the conversion but only do some
1535 // minimum validation for the eligibility of the conversion, and *tensor will
1536 // be set to nullptr.
1537 Status PrepareTensorForShape(Converter* converter,
1538                              const TRT_TensorOrWeights& input,
1539                              const DimsAdapter& dims,
1540                              const bool validation_only,
1541                              ITensorProxyPtr* tensor, const NodeDef& node_def,
1542                              std::optional<int> op_instance,
1543                              std::optional<std::string> origin_node_name) {
1544   DimsAdapter input_dims(input.GetTrtDims());
1545   // The input shape may have -1s for dynamic shape. The target shape may have
1546   // 0s representing copy over the corresponding input dimensions. It may also
1547   // have at most one -1 representing a dimension value that needs to be
1548   // inferred. If none of those special values is present, we verify that the total
1549   // sizes of the input and output shape are the same.
1550   // TODO(tfeher): Verify that the total sizes of the input and output shape are
1551   // the same in the presence of 0s but no -1 in the target shape.
1552   // If an input is a weight, it is going to become a tensor via
1553   // CreateConstantLayer. So we can treat it as a tensor for
1554   // AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
1555   if (dims.Volume() > 0 && AreDimsStaticWithDifferentSize(input_dims, dims)) {
1556     return errors::InvalidArgument(
1557         "Incompatible shapes: ", input_dims.DebugString(), " vs. ",
1558         dims.DebugString());
1559   }
1560   // ConstantLayer requires static shapes (cannot infer -1).
1561   if (input.is_weights() && !dims.IsStatic()) {
1562     return errors::InvalidArgument("Shape is not fully defined: ",
1563                                    dims.DebugString());
1564   }
1565   if (validation_only) {
1566     *tensor = nullptr;
1567     return Status::OK();
1568   }
1569 
1570   TFTRT_RETURN_ERROR_IF_NULLPTR(converter, "converter is nullptr");
1571   if (input.is_tensor()) {
1572     if (input_dims == dims) {
1573       *tensor = input.tensor();
1574     } else {
1575       nvinfer1::IShuffleLayer* layer =
1576           converter->network()->addShuffle(*input.tensor()->trt_tensor());
1577       TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape");
1578       converter->SetLayerName(layer, node_def, "shuffle", op_instance,
1579                               origin_node_name);
1580       layer->setReshapeDimensions(dims.AsTrtDims());
1581       *tensor = layer->getOutput(0);
1582     }
1583   } else {
1584     *tensor = converter->CreateConstantLayer(input.weights(), dims.AsTrtDims());
1585     TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape");
1586   }
1587   return Status::OK();
1588 }
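// Editorial note: a minimal sketch (not in the original source) showing how a
// converter might use PrepareTensorForShape, with DimsAdapter constructed from
// a std::vector as done elsewhere in this file. `params` is a hypothetical
// OpConverterParams instance.
//
//   ITensorProxyPtr reshaped = nullptr;
//   DimsAdapter new_dims(std::vector<int>{1, -1});  // infer the second dim
//   TF_RETURN_IF_ERROR(PrepareTensorForShape(
//       params->converter, params->inputs.at(0), new_dims,
//       params->validation_only, &reshaped, params->node_def));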
1589 
1590 void Converter::ProvideQuantizationRange(ITensorProxyPtr* tensor,
1591                                          float min_range, float max_range) {
1592   float symmetric_range = std::max(std::abs(min_range), std::abs(max_range));
1593   if ((*tensor)->is_trt_tensor()) {
1594     quantization_ranges_[(*tensor)->trt_tensor()] = symmetric_range;
1595   } else if ((*tensor)->is_simple_tensor()) {
1596     quantization_ranges_proxy_[tensor] = symmetric_range;
1597   }
1598 }
1599 
1600 void Converter::MaybeApplyQuantizationRanges() {
1601   if (precision_mode() != TrtPrecisionMode::INT8) return;
1602 
1603   // Apply ranges.
1604   for (auto pair : quantization_ranges_) {
1605     nvinfer1::ITensor* tensor = pair.first;
1606     const float range = pair.second;
1607     VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range;
1608     // TODO(laigd): if 'tensor' already has a range set which doesn't match
1609     // 'range', it should report error.
1610     tensor->setDynamicRange(-range, range);
1611   }
1612   for (auto pair : quantization_ranges_proxy_) {
1613     ITensorProxyPtr tensor = *pair.first;
1614     const float range = pair.second;
1615     VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range;
1616     // TODO(laigd): if 'tensor' already has a range set which doesn't match
1617     // 'range', it should report error.
1618     tensor->setDynamicRange(-range, range);
1619   }
1620 }
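// Editorial note: a hedged sketch (not in the original source) of how a
// converter could record a symmetric dynamic range for INT8 mode; the range is
// later applied to the TRT tensor by MaybeApplyQuantizationRanges() above.
// `layer` and `converter` are hypothetical.
//
//   ITensorProxyPtr relu_out = layer->getOutput(0);
//   converter->ProvideQuantizationRange(&relu_out, 0.0f, 6.0f);  // e.g. Relu6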
1621 
1622 Status Converter::GetInputs(const NodeDef& node_def,
1623                             std::vector<TRT_TensorOrWeights>* inputs) const {
1624   for (auto const& input_name : node_def.input()) {
1625     /*************************************************************************
1626      * TODO(jie): handle case 1) here.
1627      * Normalizes the inputs and extracts associated metadata:
1628      * 1) Inputs can contain a colon followed by a suffix of characters.
1629      *    That suffix may be a single number (e.g. inputName:1) or several
1630      *    word characters separated from a number by a colon
1631      *    (e.g. inputName:foo:1). The
1632      *    latter case is used to denote inputs and outputs of functions.
1633      * 2) Control dependency inputs contain a caret at the beginning; we
1634      *    remove it and annotate the edge as a control dependency.
1635      ************************************************************************/
1636     // skip control nodes
1637     if (input_name[0] == '^') continue;
1638     string name = input_name;
1639     auto last = name.find_last_of(':');
1640     // TODO(aaroey): use TensorId
1641     if (last != string::npos && last + 2 == name.size() &&
1642         name[last + 1] == '0') {
1643       name.erase(last);
1644     }
1645 
1646     if (trt_tensors_.count(name)) {
1647       TRT_TensorOrWeights input = trt_tensors_.at(name);
1648       inputs->push_back(input);
1649       VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
1650     } else {
1651       // TODO(aaroey): this should not happen, make it a CHECK.
1652       // TODO(aaroey): use StrCat for pattern like this.
1653       string msg("Node ");
1654       StrAppend(&msg, node_def.name(), " should have an input named '", name,
1655                 "' but it is not available");
1656       LOG(ERROR) << msg;
1657       return errors::InvalidArgument(msg);
1658     }
1659   }
1660   return Status::OK();
1661 }
1662 
1663 // Checks that the number of inputs match, and enforces that the inputs marked
1664 // as weights are constant. Inputs are allowed to be both weight and tensor.
1665 Status CheckInputsWeights(
1666     const OpConverterParams& params,
1667     const std::vector<std::pair<string, TrtInputArg>>& expected_inputs) {
1668   const auto& inputs = params.inputs;
1669   const auto& node_def = params.node_def;
1670   TFTRT_CHECK_INPUT_SIZE(inputs.size(), expected_inputs.size(), node_def);
1671   for (int i = 0; i < inputs.size(); i++) {
1672     if (expected_inputs[i].second == TrtInputArg::kWeight &&
1673         !inputs.at(i).is_weights()) {
1674       return errors::Unimplemented("The input \"", expected_inputs[i].first,
1675                                    "\" for ", node_def.op(),
1676                                    " must be a constant");
1677     }
1678     // TODO(tfeher): Remove this check and provide a method to automatically
1679     // retrieve an input as a tensor, converting via CreateConstantLayer if it
1680     // was originally a weight. We will want a caching mechanism to prevent many
1681     // duplicate constants from being created.
1682     if (expected_inputs[i].second == TrtInputArg::kTensor &&
1683         !inputs.at(i).is_tensor()) {
1684       return errors::Unimplemented("The input \"", expected_inputs[i].first,
1685                                    "\" for ", node_def.op(),
1686                                    " must be a tensor");
1687     }
1688     if (expected_inputs[i].second == TrtInputArg::kResource &&
1689         !inputs.at(i).is_resource()) {
1690       return errors::Unimplemented("The input \"", expected_inputs[i].first,
1691                                    "\" for ", node_def.op(),
1692                                    " must be a resource handle");
1693     }
1694   }
1695   return Status::OK();
1696 }
1697 
1698 // Checks that the number of inputs match, and enforces that the inputs marked
1699 // as true are constant weights. true means that the input must be a weight,
1700 // while false means the input must be a tensor.
1701 Status CheckInputsWeights(
1702     const OpConverterParams& params,
1703     const std::vector<std::pair<string, bool>>& inputs_is_weight) {
1704   std::vector<std::pair<string, TrtInputArg>> expected_inputs;
1705   expected_inputs.reserve(inputs_is_weight.size());
1706   std::transform(
1707       inputs_is_weight.begin(), inputs_is_weight.end(),
1708       std::back_inserter(expected_inputs), [](std::pair<string, bool> x) {
1709         return std::make_pair(
1710             x.first, x.second ? TrtInputArg::kWeight : TrtInputArg::kTensor);
1711       });
1712   return CheckInputsWeights(params, expected_inputs);
1713 }
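// Editorial note: illustrative calls (not in the original source) mirroring
// how the converters below declare their expected inputs. In the bool form,
// true means the input must be a constant weight and false means it must be a
// tensor.
//
//   TF_RETURN_IF_ERROR(
//       CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
//   TF_RETURN_IF_ERROR(CheckInputsWeights(
//       *params, {{"input", TrtInputArg::kTensor},
//                 {"axis", TrtInputArg::kWeight}}));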
1714 
1715 Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
1716                         const string type_attr_name_in = "") {
1717   string type_attr_name;
1718   if (type_attr_name_in.empty()) {
1719     if (node_def.op() == "ReadVariableOp" ||
1720         node_def.op() == "ResourceGather") {
1721       type_attr_name = "dtype";
1722     } else {
1723       type_attr_name = "T";
1724     }
1725   } else {
1726     type_attr_name = type_attr_name_in;
1727   }
1728 
1729   AttrSlice attrs(node_def);
1730   if (attrs.FindByString(type_attr_name) == nullptr) {
1731     return errors::InvalidArgument("Attribute with name ", type_attr_name,
1732                                    " not found.");
1733   }
1734   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, type_attr_name, tf_type));
1735   return Status::OK();
1736 }
1737 
1738 Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
1739                       int pos) {
1740   const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
1741   if (inputs.size() <= pos) {
1742     return errors::Internal("Invalid input position");
1743   }
1744 
1745   return inputs[pos].GetTfType(tf_type);
1746 }
1747 
1748 Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
1749   return GetNodeDefTfType(params.node_def, tf_type);
1750 }
1751 
1752 Status AllowDataTypes(const OpConverterParams& params,
1753                       const std::set<DataType>& allowed_types,
1754                       const char* type_attr_name = "") {
1755   const auto& node_def = params.node_def;
1756   DataType tf_type;
1757   TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
1758   if (!allowed_types.count(tf_type)) {
1759     string allowed_types_string = absl::StrJoin(
1760         allowed_types, ", ", [](string* out, const DataType& type) {
1761           absl::StrAppendFormat(out, "%s", DataTypeString(type));
1762         });
1763     return errors::Unimplemented(
1764         "Data type ", DataTypeString(tf_type), " is not supported for ",
1765         node_def.op(), ", must be one of [", allowed_types_string, "]");
1766   }
1767   return Status::OK();
1768 }
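// Editorial note: an illustrative call (not in the original source) showing
// the common data-type guard used by the converters below.
//
//   TF_RETURN_IF_ERROR(AllowDataTypes(
//       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));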
1769 
1770 namespace {
1771 // Extracts the spatial dimensions from `output_sizes` and returns them as a
1772 // vector of size 2.
1773 std::vector<int64_t> GetSpatialDimsFromOutputSizes(
1774     const TRT_TensorOrWeights& output_sizes, const int h_index,
1775     const int w_index) {
1776   // We use h_index and w_index instead of 1 and 2 because we haven't
1777   // transposed output_sizes along with the input.
1778   const TRT_ShapedWeights& weights = output_sizes.weights();
1779   const int output_sizes_length = weights.count();
1780   auto output_sizes_values = weights.GetPointer<int>();
1781   // The length of output_sizes can be 2 or 4. When it is 2, it is <height,width>;
1782   // when it is 4, the spatial dims are read from h_index and w_index.
1783   return {output_sizes_values[output_sizes_length == 4 ? h_index : 0],
1784           output_sizes_values[output_sizes_length == 4 ? w_index : 1]};
1785 }
1786 }  // namespace
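// Editorial note: a worked example (not in the original source). With NHWC
// data (h_index = 1, w_index = 2):
//   output_sizes = {8, 14, 14, 64}  -> returns {14, 14}
//   output_sizes = {14, 14}         -> returns {14, 14}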
1787 
1788 Status ConvertConv2DHelper(OpConverterParams* params, int group,
1789                            bool is_conv2d_backprop_input) {
1790   const auto& inputs = params->inputs;
1791   const auto& node_def = params->node_def;
1792   TRT_TensorOrWeights backprop_output_size;
1793   ITensorProxyPtr tensor = nullptr;
1794   if (is_conv2d_backprop_input) {
1795     // In the case when Conv2dBackpropInput is used for conv2d_transpose, these
1796     // inputs correspond to: output size, filter, and input.
1797     // TODO(cbate): refine this check when moving to structured op converter.
1798     if (!params->use_explicit_precision) {
1799       TF_RETURN_IF_ERROR(CheckInputsWeights(
1800           *params,
1801           {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
1802     }
1803 
1804     backprop_output_size = inputs.at(0);
1805     tensor = inputs.at(2).tensor();
1806     bool has_dynamic_hw_shape{false};
1807     int start_idx{0};
1808     auto dims = tensor->getDimensions();
1809     if (params->use_implicit_batch) {
1810       if (dims.nbDims != 3) {
1811         return errors::Internal(
1812             "In implicit batch mode, input nbDims should be 3");
1813       }
1814       start_idx = 1;
1815     } else {
1816       if (dims.nbDims != 4) {
1817         return errors::Internal(
1818             "In explicit batch mode, input nbDims should be 4");
1819       }
1820       start_idx = 2;
1821     }
1822     for (int i = start_idx; i < dims.nbDims; ++i) {
1823       if (dims.d[i] < 0) {
1824         has_dynamic_hw_shape = true;
1825       }
1826     }
1827     if (has_dynamic_hw_shape) {
1828       return errors::Unimplemented(
1829           "Conv2dBackpropInput does not support input with unknown spatial "
1830           "shape");
1831     }
1832   } else {
1833     TF_RETURN_IF_ERROR(CheckInputsWeights(
1834         *params,
1835         {{"input", false}, {"filter", !params->use_explicit_precision}}));
1836     tensor = inputs.at(0).tensor();
1837   }
1838   TF_RETURN_IF_ERROR(
1839       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
1840 
1841   if (inputs.at(1).GetTrtDims().nbDims != 4) {
1842     return errors::InvalidArgument("Conv2D expects kernel of dimension 4");
1843   }
1844 
1845   string data_format, padding_type;
1846   std::vector<int64_t> tf_dilations, tf_stride;
1847   AttrSlice attrs(node_def);
1848   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
1849   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding_type));
1850   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "dilations", &tf_dilations));
1851   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &tf_stride));
1852 
1853   int c_index = (data_format == "NHWC") ? 3 : 1;
1854   int h_index = (data_format == "NHWC") ? 1 : 2;
1855   int w_index = (data_format == "NHWC") ? 2 : 3;
1856 
1857   if (tf_dilations.size() != 4) {
1858     return errors::InvalidArgument(
1859         "Convolution dilations field must specify 4 dimensions");
1860   }
1861   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
1862     return errors::Unimplemented(
1863         "Dilation rate must be 1 for batch and channel dimensions");
1864   }
1865   const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
1866   if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) {
1867     return errors::Unimplemented(
1868         "Dilation with Conv2DBackpropInput (conv2d_transpose) is not"
1869         " supported");
1870   }
1871 
1872   if (tf_stride.size() != 4) {
1873     return errors::InvalidArgument(
1874         "Convolution strides field must specify 4 dimensions");
1875   }
1876   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
1877     return errors::Unimplemented(
1878         "Stride must be 1 for batch and channel dimensions");
1879   }
1880   // Channel dim must be static for DepthwiseConv2dNative since we use that
1881   // value for num_groups at build time.
1882   if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) {
1883     return errors::InvalidArgument("Channel dimension must be static");
1884   }
1885 
1886   if (padding_type != "SAME" && padding_type != "VALID") {
1887     return errors::Unimplemented(padding_type +
1888                                  " padding type not implemented, "
1889                                  "only VALID and SAME are supported");
1890   }
1891   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
1892   if (params->validation_only) return Status::OK();
1893 
1894   // Transpose to NCHW (NCHW is required for IConvLayer).
1895   const bool need_transpose = (data_format == "NHWC");
1896   if (need_transpose) {
1897     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
1898         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
1899   }
1900   // Dimensions of transposed tensor.
1901   const auto tensor_dim = tensor->getDimensions();
1902   const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1];
1903 
1904   // group == 0 signifies that this is a depthwise convolution, so set
1905   // num_groups to the size of the input's channel dim. For a non-depthwise conv,
1906   // num_groups will be 1.
1907   const int num_groups = (group == 0) ? c_dim_size : group;
1908 
1909   // For conv, TF weights are RSCK, and TRT expects KCRS.
1910   // For backprop, TF weights are RSKC, and TRT expects CKRS.
1911   // Therefore, this reorder will work for both cases.
1912   const int output_axis = is_conv2d_backprop_input ? 2 : 3;
1913   auto weights_shape = inputs.at(1).GetTrtDims();
1914   const int noutput = weights_shape.d[output_axis] * num_groups;
1915   nvinfer1::DimsHW kernel_size;
1916   kernel_size.h() = weights_shape.d[0];
1917   kernel_size.w() = weights_shape.d[1];
1918 
1919   TRT_ShapedWeights weights_rsck;
1920   if (inputs.at(1).is_weights()) {
1921     weights_rsck = inputs.at(1).weights();
1922   } else {
1923     StatusOr<TRT_ShapedWeights> tmp = params->weight_store->GetTempWeights(
1924         nvinfer1::DataType::kFLOAT, weights_shape);
1925     TRT_ENSURE_OK(tmp);
1926     weights_rsck = std::move(tmp).value();
1927   }
1928 
1929   // In explicit precision mode, trace the input back to the constant while also
1930   // verifying that QDQ scale layers are present.
1931   if (!inputs.at(1).is_weights()) {
1932     TRT_ENSURE(params->use_explicit_precision);
1933     StatusOr<TRTNetworkBuilder> builder = TRTNetworkBuilder::Create(
1934         params->converter->network(), params->weight_store);
1935     TRT_ENSURE_OK(builder);
1936     auto dequant_layer =
1937         builder->FindProducerOf(inputs.at(1).tensor()->trt_tensor());
1938     TRT_ENSURE_PTR_OK(dequant_layer);
1939 
1940     // TODO(cbate): corresponding TRT layer name check
1941     if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
1942       TRT_ENSURE((*dequant_layer)->getType() == nvinfer1::LayerType::kSCALE);
1943     }
1944 
1945     auto quant_layer = builder->UniqueParentOf(*dequant_layer, 0);
1946     TRT_ENSURE_PTR_OK(quant_layer);
1947 
1948     // TODO(cbate): corresponding TRT layer name check
1949     if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
1950       TRT_ENSURE((*quant_layer)->getType() == nvinfer1::LayerType::kSCALE);
1951     }
1952 
1953     auto weights_layer = builder->UniqueParentOf(*quant_layer, 0);
1954     TRT_ENSURE_PTR_OK(weights_layer);
1955     TRT_ENSURE((*weights_layer)->getType() == nvinfer1::LayerType::kCONSTANT);
1956     auto const_weights_rsck =
1957         reinterpret_cast<nvinfer1::IConstantLayer*>(*weights_layer)
1958             ->getWeights();
1959 
1960     TRT_ENSURE(weights_rsck.count() == const_weights_rsck.count);
1961     const auto* weights_ptr =
1962         static_cast<const float*>(const_weights_rsck.values);
1963     std::copy_n(weights_ptr, const_weights_rsck.count,
1964                 weights_rsck.GetPointer<float>());
1965   }
1966 
1967   StatusOr<TRT_ShapedWeights> weights =
1968       params->weight_store->GetTempWeights(weights_rsck);
1969   TRT_ENSURE_OK(weights);
1970   StatusOr<TRT_ShapedWeights> biases = params->weight_store->GetTempWeights(
1971       nvinfer1::DataType::kFLOAT, nvinfer1::Dims{1, {noutput}});
1972   TRT_ENSURE_OK(biases);
1973   std::fill_n(biases->GetPointer<float>(), noutput, 0.0f);
1974   ReorderRSCKToKCRS(weights_rsck, &*weights, num_groups);
1975 
1976   // Add convolution.
1977   nvinfer1::ILayer* conv_layer = nullptr;
1978   if (is_conv2d_backprop_input) {
1979     nvinfer1::IDeconvolutionLayer* layer =
1980         params->converter->network()->addDeconvolution(
1981             *tensor->trt_tensor(), noutput, kernel_size,
1982             weights->GetTrtWeights(), biases->GetTrtWeights());
1983     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
1984     layer->setStride(stride);
1985     // VALID padding is the default TRT behavior.
1986     if (padding_type == "SAME") {
1987       // SAME_UPPER means that post padding is preferred.
1988       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
1989     }
1990     layer->setNbGroups(num_groups);
1991     conv_layer = layer;
1992   } else {
1993     const nvinfer1::Weights empty_weights{nvinfer1::DataType::kFLOAT, nullptr,
1994                                           0};
1995     nvinfer1::IConvolutionLayer* layer =
1996         params->converter->network()->addConvolution(
1997             *tensor->trt_tensor(), noutput, kernel_size,
1998             params->use_explicit_precision ? empty_weights
1999                                            : weights->GetTrtWeights(),
2000             empty_weights);
2001     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2002     layer->setStride(stride);
2003     if (padding_type == "SAME") {
2004       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
2005     }
2006     layer->setNbGroups(num_groups);
2007     layer->setDilation(dilation);
2008     conv_layer = layer;
2009   }
2010 
2011   // After creating the conv layer, if we are in explicit precision mode and the
2012   // weights input is a tensor, then we need to override the weights input by
2013   // calling setInput() on the layer.
2014   if (params->use_explicit_precision) {
2015     TRT_ENSURE(inputs.at(1).is_tensor());
2016 
2017     conv_layer->setInput(1, *inputs.at(1).tensor()->trt_tensor());
2018   }
2019 
2020   params->converter->SetLayerName(conv_layer, node_def, "conv");
2021   ITensorProxyPtr output_tensor = conv_layer->getOutput(0);
2022   // Add an extra padding for Deconv because TRT doesn't accept the
2023   // argument output_shape and thus the TRT output shape could be wrong
2024   // in case of strides>1.
2025   if (is_conv2d_backprop_input) {
2026     std::vector<int64_t> output_spatial_dims =
2027         GetSpatialDimsFromOutputSizes(backprop_output_size, h_index, w_index);
2028     const int output_height = output_spatial_dims[0];
2029     const int output_width = output_spatial_dims[1];
2030     nvinfer1::Dims trt_output_shape = output_tensor->getDimensions();
2031     // What determines the padding size is the difference between the given
2032     // input_sizes (tf_output_shape) and the TRT-computed size.
2033     int out_h_idx = params->use_implicit_batch ? 1 : 2;
2034     int out_w_idx = params->use_implicit_batch ? 2 : 3;
2035     const int height_diff = output_height - trt_output_shape.d[out_h_idx];
2036     const int width_diff = output_width - trt_output_shape.d[out_w_idx];
2037     if ((height_diff < 0) || (width_diff < 0)) {
2038       return errors::InvalidArgument(
2039           "input_sizes argument of Conv2DBackprop (i.e. output_shape argument "
2040           "of conv2d_transpose) ",
2041           "is too small for the given out_backprop argument of Conv2DBackprop "
2042           "(i.e. input argument of conv2d_transpose). Expect: ",
2043           "(", output_height, ", ", output_width, ") >= ", "(",
2044           trt_output_shape.d[out_h_idx], ", ", trt_output_shape.d[out_w_idx],
2045           ")");
2046     }
2047     // Only add a padding layer if padding sizes are larger than 0
2048     if ((height_diff > 0) || (width_diff > 0)) {
2049       nvinfer1::DimsHW pre_padding(0, 0);
2050       nvinfer1::DimsHW post_padding(height_diff, width_diff);
2051       nvinfer1::IPaddingLayer* padding_layer =
2052           params->converter->network()->addPadding(*output_tensor->trt_tensor(),
2053                                                    pre_padding, post_padding);
2054       output_tensor = padding_layer->getOutput(0);
2055       params->converter->SetLayerName(padding_layer, node_def, "pad");
2056     }
2057   }
2058   // Restore transpose.
2059   if (need_transpose) {
2060     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2061         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
2062   }
2063   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2064   return Status::OK();
2065 }
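// Editorial note: a worked example (not in the original source) of the
// deconvolution output padding computed above. If input_sizes requests a
// 15x15 spatial output but TRT computes 14x14 for the given stride, then
// height_diff = width_diff = 1, and a post-padding of (1, 1) is appended so
// the engine output matches the requested conv2d_transpose output_shape.
// Likewise, the weight reorder maps an RSCK filter of shape [3, 3, 64, 128]
// to the KCRS layout [128, 64, 3, 3] that the TRT convolution layer expects.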
2066 
2067 bool AllowInefficientTranspose() {
2068   static bool result = [] {
2069     bool value;
2070     Status status =
2071         ReadBoolFromEnvVar("TF_DEBUG_TRT_ALLOW_INEFFICIENT_TRANSPOSE",
2072                            /*default_value=*/false, &value);
2073     if (!status.ok()) {
2074       LOG(ERROR) << status;
2075     }
2076     return value;
2077   }();
2078 
2079   return result;
2080 }
2081 
2082 Status ConvertTranspose(OpConverterParams* params) {
2083   const auto& inputs = params->inputs;
2084   TF_RETURN_IF_ERROR(
2085       CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
2086   TF_RETURN_IF_ERROR(AllowDataTypes(
2087       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2088   // Get the permutation from weights.
2089   TRT_ShapedWeights weights = inputs.at(1).weights();
2090   const int* weights_ptr = weights.GetPointer<int>();
2091   std::vector<int> perm(weights_ptr, weights_ptr + weights.count());
2092 
2093   // Verify the permutation.
2094   ITensorProxyPtr input_tensor = inputs.at(0).tensor();
2095   const int perm_size =
2096       params->use_implicit_batch ? perm.size() - 1 : perm.size();
2097   if (perm_size != size_t(input_tensor->getDimensions().nbDims)) {
2098     return errors::InvalidArgument(
2099         "Rank of perm for transpose does not match with that of the input.");
2100   }
2101   if (params->use_implicit_batch && perm[0] != 0) {
2102     return errors::Unimplemented(
2103         "Transpose at batch dimension is not supported.");
2104   }
2105 
2106   if (!IS_TRT_VERSION_GE(7, 1, 3, 4)) {
2107     // TensorRT versions before 7.1.3.4 are slow at transposing large tensors.
2108     // So check tensor size, and don't convert if it is too large.
2109     constexpr int64_t kMaxEfficientTranspose = 2500000;
2110     int64_t tensor_size = DimsAdapter(input_tensor->getDimensions()).Volume();
2111     if (!AllowInefficientTranspose() && tensor_size > kMaxEfficientTranspose) {
2112       return errors::Unimplemented(StrCat("Transpose too large:", tensor_size));
2113     }
2114   }
2115 
2116   if (params->validation_only) return Status::OK();
2117 
2118   // Start conversion.
2119   ITensorProxyPtr output_tensor = nullptr;
2120   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2121       input_tensor, perm, &output_tensor, params->node_def));
2122   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2123   return Status::OK();
2124 }
2125 
2126 Status ConvertShape(OpConverterParams* params) {
2127   const auto& inputs = params->inputs;
2128   TF_RETURN_IF_ERROR(
2129       CheckInputsWeights(*params, {{"input", TrtInputArg::kBoth}}));
2130   if (params->use_implicit_batch) {
2131     return errors::Unimplemented(
2132         "Shape is only supported for explicit batch mode.");
2133   }
2134   DimsAdapter input_dims(inputs.at(0).GetTrtDims());
2135   if (params->validation_only) return Status::OK();
2136 
2137   StatusOr<TRTNetworkBuilder> builder = TRTNetworkBuilder::Create(
2138       params->converter->network(), params->weight_store);
2139   TRT_ENSURE_OK(builder);
2140   if (input_dims.IsStatic()) {
2141     // Create a const node with the value of the shape.
2142     StatusOr<nvinfer1::IConstantLayer*> const_layer =
2143         builder->ConstantShape(input_dims);
2144     TRT_ENSURE_PTR_OK(const_layer);
2145     params->outputs->push_back(
2146         TRT_TensorOrWeights((*const_layer)->getOutput(0)));
2147     return Status::OK();
2148   }
2149   StatusOr<nvinfer1::IShapeLayer*> shape_layer =
2150       builder->Shape(inputs.at(0).tensor()->trt_tensor());
2151   TRT_ENSURE_PTR_OK(shape_layer);
2152   params->converter->SetLayerName(*shape_layer, params->node_def, "shape");
2153   params->outputs->push_back(TRT_TensorOrWeights((*shape_layer)->getOutput(0)));
2154   return Status::OK();
2155 }
2156 
2157 Status ExpectShapeTensor(const TRT_TensorOrWeights& tensor) {
2158   if (tensor.tensor()->getType() != nvinfer1::DataType::kINT32) {
2159     return errors::InvalidArgument("Expected a shape tensor with INT32 type");
2160   }
2161   if (tensor.GetTrtDims().nbDims > 1) {
2162     return errors::InvalidArgument("Expected a 0D or 1D shape tensor");
2163   }
2164   return Status::OK();
2165 }
2166 
2167 // Converts Reshape op if the input has dynamic (unknown) dims.
2168 Status ConvertDynamicReshape(OpConverterParams* params) {
2169   if (params->use_implicit_batch) {
2170     return errors::InvalidArgument(
2171         "The input \"shape\" for Reshape must be a constant in implicit batch"
2172         " mode.");
2173   }
2174   if (!IS_TRT_VERSION_GE(7, 1, 3, 0)) {
2175     // While TRT officially supports shape value inputs, there are problems with
2176     // shape input handling that cause networks converted with
2177     // ConvertDynamicReshape to fail. Here we conservatively switch off the
2178     // converter before TRT 7.1.3.
2179     return errors::InvalidArgument(
2180         "Non constant shape input tensor for Reshape requires minimum TRT "
2181         "7.1.3");
2182   }
2183   const auto& inputs = params->inputs;
2184   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2185 
2186   // If the input is a tensor it must be a shape tensor.
2187   TF_RETURN_IF_ERROR(ExpectShapeTensor(inputs.at(1)));
2188   if (inputs.at(1).tensor()->getDimensions().nbDims == 0) {
2189     // Dynamic reshape requires a 1D shape tensor.
2190     return errors::Unimplemented(
2191         "Reshape with dynamic input requires 1D input tensor");
2192   }
2193   if (params->validation_only) return Status::OK();
2194   nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle(
2195       *input_tensor.tensor()->trt_tensor());
2196   VLOG(2) << "ConvertReshape setInput (1) "
2197           << DebugString(inputs.at(1).tensor()->getDimensions());
2198   layer->setInput(1, *inputs.at(1).tensor()->trt_tensor());
2199   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
2200   return Status::OK();
2201 }
2202 
2203 // Converts Reshape in explicit batch mode if the input has static (known) dims.
2204 Status ConvertStaticReshapeForExplicitBatchMode(
2205     OpConverterParams* params, DimsAdapter output_dims,
2206     ITensorProxyPtr* output_tensor) {
2207   return PrepareTensorForShape(params->converter, params->inputs.at(0),
2208                                output_dims, params->validation_only,
2209                                output_tensor, params->node_def);
2210 }
2211 
2212 // Converts Reshape in implicit batch mode. The input has static (known) dims.
2213 Status ConvertStaticReshapeForImplicitBatchMode(
2214     OpConverterParams* params, DimsAdapter output_dims,
2215     ITensorProxyPtr* output_tensor) {
2216   const auto& inputs = params->inputs;
2217   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2218   const int input_batch_dim = input_tensor.batch_size();
2219   const int64_t output_batch_dim = output_dims.dim(0);
2220 
2221   DimsAdapter input_nonbatch_dims(input_tensor.GetTrtDims());
2222   DimsAdapter output_nonbatch_dims(output_dims);
2223   TF_RETURN_IF_ERROR(output_nonbatch_dims.RemoveBatchDimension());
2224 
2225   VLOG(1) << "input_batch_dim=" << input_batch_dim
2226           << ", input_nonbatch_dims=" << input_nonbatch_dims.DebugString()
2227           << "\nresult_batch_dim=" << output_batch_dim
2228           << ", result_nonbatch_dims=" << output_nonbatch_dims.DebugString();
2229 
2230   // Check whether input_batch_dim and output_batch_dim will have the same
2231   // static value.
2232   bool reshape_may_change_batch_dim = false;
2233   if (input_batch_dim != -1 && output_batch_dim != -1) {
2234     reshape_may_change_batch_dim = (input_batch_dim != output_batch_dim);
2235   } else {
2236     reshape_may_change_batch_dim =
2237         !AreDimsStaticWithSameSize(input_nonbatch_dims, output_nonbatch_dims);
2238   }
2239   if (reshape_may_change_batch_dim) {
2240     return errors::Unimplemented("Reshape on batch dimension is not supported");
2241   }
2242   // Perform the conversion.
2243   return PrepareTensorForShape(params->converter, input_tensor,
2244                                output_nonbatch_dims, params->validation_only,
2245                                output_tensor, params->node_def);
2246 }
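// Editorial note: a worked example (not in the original source) of the batch
// dimension check above in implicit batch mode. For an input with batch size
// 8 and non-batch dims [4, 6], reshaping to [8, 24] is accepted (the batch
// dim is unchanged and 4*6 == 24), whereas reshaping to [2, 96] is rejected
// because it would change the batch dimension.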
2247 
2248 Status ConvertReshape(OpConverterParams* params) {
2249   const auto& inputs = params->inputs;
2250   TF_RETURN_IF_ERROR(CheckInputsWeights(
2251       *params,
2252       {{"tensor", TrtInputArg::kTensor}, {"shape", TrtInputArg::kBoth}}));
2253   TF_RETURN_IF_ERROR(AllowDataTypes(
2254       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2255   if (inputs.at(1).is_tensor()) {
2256     return ConvertDynamicReshape(params);
2257   }
2258 
2259   // TODO(bixia): we can't use inputs.at(1).weights().ToVector<int>() for two
2260   // reasons: (1) When weights.count()==0, TRT_ShapedWeights::tensor_ dtype is
2261   // not properly set to INT32. (2) I tried a fix for the first problem, I got
2262   // shared pointer related error in convert_nodes_test. We should fix the
2263   // problems and switch to use inputs.at(1).weights().ToVector<int>(), a type
2264   // safe method to access the content of the tensor.
2265   TRT_ShapedWeights weights = inputs.at(1).weights();
2266   if (weights.count() == 0 && params->use_implicit_batch) {
2267     return errors::Unimplemented("Reshape to shape=[] is not supported");
2268   }
2269 
2270   DimsAdapter output_shape_dims(
2271       absl::MakeSpan(weights.GetPointer<int>(), weights.count()));
2272   ITensorProxyPtr output_tensor = nullptr;
2273 
2274   if (!params->use_implicit_batch) {
2275     TF_RETURN_IF_ERROR(ConvertStaticReshapeForExplicitBatchMode(
2276         params, output_shape_dims, &output_tensor));
2277   } else {
2278     TF_RETURN_IF_ERROR(ConvertStaticReshapeForImplicitBatchMode(
2279         params, output_shape_dims, &output_tensor));
2280   }
2281   if (params->validation_only) return Status::OK();
2282 
2283   // Record the conversion result.
2284   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2285   return Status::OK();
2286 }
2287 
2288 Status ConvertExpandDims(OpConverterParams* params) {
2289   const auto& inputs = params->inputs;
2290   const auto& node_def = params->node_def;
2291   TF_RETURN_IF_ERROR(
2292       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
2293   TF_RETURN_IF_ERROR(AllowDataTypes(
2294       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2295   // Get input shape as vector.
2296   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2297   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2298   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2299   // Get axis to expand on.
2300   auto axis = inputs.at(1).weights().GetSpan<int>();
2301   if (axis.size() != 1) {
2302     return errors::InvalidArgument("ExpandDims axis must be a scalar");
2303   }
2304   // Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for
2305   // ExpandDims's ability to add an axis at the end of the shape.
2306   int trt_axis;
2307   TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
2308                                  params->use_implicit_batch, &trt_axis));
2309   if (params->validation_only) return Status::OK();
2310   ITensorProxyPtr output_tensor = nullptr;
2311 
2312   if (!params->use_implicit_batch && !HasStaticShape(input_dims)) {
2313     TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims(
2314         /*input=*/input_tensor.tensor(),
2315         /*dims=*/dims,
2316         /*axis=*/trt_axis,
2317         /*params=*/params,
2318         /*output=*/&output_tensor));
2319   } else {
2320     // ExpandDims: Insert new dim of size 1.
2321     input_dims.insert(input_dims.begin() + trt_axis, 1);
2322     // Reshape tensor.
2323     DimsAdapter dims(input_dims);
2324     TF_RETURN_IF_ERROR(PrepareTensorForShape(
2325         params->converter, input_tensor, dims,
2326         /*validation_only=*/false, &output_tensor, params->node_def));
2327   }
2328   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2329   return Status::OK();
2330 }
2331 
2332 Status Converter::DynamicReshape(ITensorProxyPtr input,
2333                                  std::vector<std::pair<int, int>> slices,
2334                                  OpConverterParams* params,
2335                                  ITensorProxyPtr* output,
2336                                  std::vector<int> size_for_added_dims,
2337                                  std::optional<int> op_instance) {
2338   *output = nullptr;
2339   // DynamicReshape relies on INetworkDefinition::addShape
2340   if (params->validation_only) {
2341     return errors::Internal(
2342         "DynamicReshape should not be used during validation");
2343   }
2344   ITensorProxyPtr shape =
2345       network()->addShape(*input->trt_tensor())->getOutput(0);
2346   // Build new shape = shape[:trt_axis] + [1] + shape[trt_axis:]
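  // Worked example (editor's sketch, not in the original source): when called
  // from DynamicExpandDims with a rank-3 input and trt_axis = 1, we get
  // slices = {{0, 1}, {1, 3}} and size_for_added_dims = {-1, 1}. The loop
  // below then concatenates shape[0:1], the constant [1], and shape[1:3],
  // i.e. the rank-4 shape with a unit dimension inserted at axis 1.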
2347   std::vector<ITensorProxyPtr> concat_inputs;
2348   int max_num_slices = std::max(slices.size(), size_for_added_dims.size());
2349   int op_instance_value = op_instance.has_value() ? op_instance.value() : 0;
2350 
2351   for (int i = 0; i < max_num_slices; i++) {
2352     ITensorProxyPtr tensor;
2353     // maybe_add_a_dimension(i);
2354     if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) {
2355       nvinfer1::Dims dims{1, {1}};
2356       if (size_for_added_dims[i] > 0) {
2357         dims.d[0] = size_for_added_dims[i];
2358       }
2359       TF_RETURN_IF_ERROR(
2360           CreateScalarConstant(params, std::min(size_for_added_dims[i], 1),
2361                                &tensor, nvinfer1::DataType::kINT32, dims));
2362       concat_inputs.push_back(tensor);
2363     }
2364     if (i < slices.size()) {
2365       nvinfer1::ISliceLayer* slice_layer = network()->addSlice(
2366           *shape->trt_tensor(), {1, {slices[i].first}},
2367           {1, {slices[i].second - slices[i].first}}, {1, {1}});
2368       concat_inputs.push_back(slice_layer->getOutput(0));
2369       string slice_name = StrCat("slice_", op_instance_value);
2370       SetLayerName(slice_layer, params->node_def, slice_name,
2371                    /*op_instance=*/i);
2372     }
2373   }
2374   std::vector<nvinfer1::ITensor*> trt_concat_inputs;
2375   for (const auto& t : concat_inputs) {
2376     trt_concat_inputs.push_back(t->trt_tensor());
2377   }
2378   nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
2379       static_cast<nvinfer1::ITensor* const*>(trt_concat_inputs.data()),
2380       concat_inputs.size());
2381   SetLayerName(concat_layer, params->node_def, "concat", op_instance);
2382   concat_layer->setAxis(0);
2383   ITensorProxyPtr new_shape = concat_layer->getOutput(0);
2384   // Reshape input using new shape
2385   nvinfer1::IShuffleLayer* shuffle =
2386       network()->addShuffle(*input->trt_tensor());
2387   SetLayerName(shuffle, params->node_def, "shuffle", op_instance);
2388   shuffle->setInput(1, *new_shape->trt_tensor());
2389   *output = shuffle->getOutput(0);
2390   return Status::OK();
2391 }
2392 
2393 Status Converter::DynamicExpandDims(ITensorProxyPtr input,
2394                                     const nvinfer1::Dims& dims, int axis,
2395                                     OpConverterParams* params,
2396                                     ITensorProxyPtr* output,
2397                                     std::optional<int> op_instance) {
2398   if (params->validation_only) {
2399     *output = nullptr;
2400     return errors::Internal(
2401         "DynamicExpandDims should not be used during validation");
2402   }
2403   std::vector<std::pair<int, int>> slices;
2404   std::vector<int> extra_dims;
2405   if (axis != 0) {
2406     slices.push_back(std::pair<int, int>{0, axis});
2407     extra_dims.push_back(-1);
2408   }
2409   extra_dims.push_back(1);
2410   if (axis != dims.nbDims) {
2411     slices.push_back(std::pair<int, int>{axis, dims.nbDims});
2412   }
2413   return DynamicReshape(
2414       /*input=*/input,
2415       /*slices=*/slices,
2416       /*params=*/params,
2417       /*output=*/output,
2418       /*size_for_added_dims=*/extra_dims,
2419       /*op_instance=*/op_instance);
2420 }
2421 
2422 Status Converter::SqueezeTensor(ITensorProxyPtr input,
2423                                 std::vector<int>* input_dims,
2424                                 OpConverterParams* params,
2425                                 ITensorProxyPtr* output,
2426                                 std::optional<int> op_instance) {
2427   // If the remaining dimensions of a squeeze operation have dynamic sizes, we
2428   // need to use TRT ops to build the result shape for the squeeze operation.
2429   // This is because IShuffleLayer::setReshapeDimensions treats -1 as a special
2430   // value.
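  // Worked example (editor's sketch, not in the original source): with
  // input_dims = {-1, 0, 3}, where 0 marks an axis to squeeze, the dynamic
  // path builds slices = {{0, 1}, {2, 3}}, so the new shape is the
  // concatenation of shape[0:1] and shape[2:3].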
2431   if (!params->use_implicit_batch && !HasStaticShape(*input_dims)) {
2432     std::vector<std::pair<int, int>> slices;
2433     for (int i = 0; i < input_dims->size(); i++) {
2434       if (input_dims->at(i) != 0) {
2435         slices.push_back(std::pair<int, int>(i, i + 1));
2436       }
2437     }
2438     return DynamicReshape(
2439         /*input=*/input,
2440         /*slices=*/slices,
2441         /*params=*/params,
2442         /*output=*/output,
2443         /*size_for_added_dims=*/{},
2444         /*op_instance=*/op_instance);
2445   }
2446   // Remove all dims which are equal to 0.
2447   input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0),
2448                     input_dims->end());
2449   // Reshape tensor.
2450   TF_RETURN_IF_ERROR(PrepareTensorForShape(
2451       params->converter, TRT_TensorOrWeights(input), DimsAdapter(*input_dims),
2452       /*validation_only=*/false, output, params->node_def));
2453   return Status::OK();
2454 }
2455 
2456 Status ConvertSqueeze(OpConverterParams* params) {
2457   const auto& inputs = params->inputs;
2458   const auto& node_def = params->node_def;
2459   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2460   TF_RETURN_IF_ERROR(AllowDataTypes(
2461       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2462   // Get input shape.
2463   const TRT_TensorOrWeights& input_tensor = inputs.at(0);
2464   const nvinfer1::Dims dims = input_tensor.GetTrtDims();
2465   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
2466   std::vector<int64_t> squeeze_dims;
2467   TF_RETURN_IF_ERROR(
2468       GetNodeAttr(AttrSlice(node_def), "squeeze_dims", &squeeze_dims));
2469   if (squeeze_dims.empty()) {
2470     if (params->use_implicit_batch || !HasStaticShape(dims)) {
2471       return errors::Unimplemented(
2472           "Squeeze is not implemented for empty squeeze_dims");
2473     } else {
2474       // In explicit batch mode with a static input shape, we squeeze all
2475       // singleton dimensions.
2476       for (int& dim : input_dims) {
2477         if (dim == 1) {
2478           // Mark it for removal by setting it to 0
2479           dim = 0;
2480         }
2481       }
2482     }
2483   } else {
2484     std::vector<int> trt_axes;
2485     trt_axes.reserve(squeeze_dims.size());
2486     for (int tf_axis : squeeze_dims) {
2487       // If the axis is valid, then convert it to TRT axis, otherwise abort
2488       // conversion.
2489       int trt_axis;
2490       TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
2491                                      params->use_implicit_batch, &trt_axis));
2492       // Make sure target dimension is size 1 or unknown size (-1)
2493       if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) {
2494         return errors::InvalidArgument(
2495             "Dimension ", tf_axis, " with size ", input_dims[trt_axis],
2496             " cannot be squeezed because it must be size 1");
2497       }
2498       trt_axes.push_back(trt_axis);
2499     }
2500     // Mark axes to remove by setting them to 0.
2501     for (int axis : trt_axes) {
2502       input_dims[axis] = 0;
2503     }
2504   }
2505   if (params->validation_only) return Status::OK();
2506 
2507   ITensorProxyPtr output_tensor = nullptr;
2508   TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
2509       /*input=*/input_tensor.tensor(),
2510       /*input_dims=*/&input_dims,
2511       /*params=*/params,
2512       /*output=*/&output_tensor));
2513   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2514   return Status::OK();
2515 }
2516 
2517 Status ConvertSlice(OpConverterParams* params) {
2518   const auto& inputs = params->inputs;
2519   TF_RETURN_IF_ERROR(CheckInputsWeights(
2520       *params, {{"input", false}, {"begin", true}, {"size", true}}));
2521   TF_RETURN_IF_ERROR(AllowDataTypes(
2522       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2523 
2524   const TRT_ShapedWeights& begin_weights = inputs.at(1).weights();
2525   const TRT_ShapedWeights& size_weights = inputs.at(2).weights();
2526 
2527   // Check that "begin" is not negative.
2528   if (absl::c_any_of(begin_weights.GetSpan<int32>(),
2529                      [](const int32 val) { return val < 0; })) {
2530     return errors::InvalidArgument("\"begin\" in Slice is out of range");
2531   }
2532 
2533   // Check that "size" is not less than -1.
2534   if (absl::c_any_of(size_weights.GetSpan<int32>(),
2535                      [](const int32 val) { return val < -1; })) {
2536     return errors::InvalidArgument("\"size\" in Slice is out of range");
2537   }
2538 
2539   // Get the input dims and add batch dimension so that indexes line up
2540   // properly.
2541   PartialTensorShape input_shape;
2542   TF_RETURN_IF_ERROR(
2543       DimsAdapter(inputs.at(0).GetTrtDims())
2544           .PartialTensorShape(
2545               &input_shape, params->use_implicit_batch
2546                                 ? std::optional<int>(inputs.at(0).batch_size())
2547                                 : std::nullopt));
2548 
2549   if (static_cast<int64>(input_shape.dims()) !=
2550           begin_weights.GetTensor().NumElements() ||
2551       static_cast<int64>(input_shape.dims()) !=
2552           size_weights.GetTensor().NumElements()) {
2553     return errors::InvalidArgument(
2554         "Length of begin and size arguments must equal rank of input for "
2555         "Slice");
2556   }
2557 
2558   // Check that batch dimension is unmodified.
2559   if (params->use_implicit_batch) {
2560     auto begin_v = begin_weights.GetSpan<int32>();
2561     auto size_v = size_weights.GetSpan<int32>();
2562 
2563     // The batch dimension is modified if begin doesn't start from 0 or the
2564     // slice size on d0 is not equal to the input size on d0. Slice size -1
2565     // means slicing to the end of the dimension.
2566     if (begin_v[0] != 0 ||
2567         (size_v[0] != -1 && size_v[0] != input_shape.dim_size(0))) {
2568       return errors::Unimplemented(
2569           "TensorRT does not allow modifications to the batch dimension in "
2570           "implicit batch mode");
2571     }
2572   }
2573 
2574   PartialTensorShape processing_shape;
2575   PartialTensorShape final_shape;
2576   bool is_identity;
2577   bool is_simple_slice;
2578   bool slice_dim0;
2579   absl::InlinedVector<int64, 4> begin;
2580   absl::InlinedVector<int64, 4> end;
2581   absl::InlinedVector<int64, 4> strides;
2582   StridedSliceShapeSpec strided_slice_spec;
2583   std::bitset<32> begin_mask(0);
2584   std::bitset<32> end_mask(0);
2585   std::bitset<32> ellipsis_mask(0);
2586   std::bitset<32> new_axis_mask(0);
2587   std::bitset<32> shrink_axis_mask(0);
2588   Tensor strides_tensor = tensor::DeepCopy(begin_weights.GetTensor());
2589   Tensor end_tensor = tensor::DeepCopy(size_weights.GetTensor());
2590   Tensor size_tensor = tensor::DeepCopy(size_weights.GetTensor());
2591 
2592   // Use the content in begin_weights and size_tensor to set up begin_mask,
2593   // end_mask, end_tensor, strides_tensor, and size_tensor.
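  // Worked example (editor's sketch, not in the original source): for
  // input_shape = [2, 10], begin = [0, 2] and size = [-1, 3], the loop below
  // sets strides = [1, 1], marks end_mask for dim 0 (size -1 means "slice to
  // the end"), and computes end[1] = begin[1] + size[1] = 5, i.e. the
  // equivalent strided slice [0:, 2:5].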
2594   auto strides_vec = strides_tensor.flat<int32>();
2595   auto end_vec = end_tensor.flat<int32>();
2596   auto size_vec = size_tensor.flat<int32>();
2597   auto begin_vec = begin_weights.GetTensor().flat<int32>();
2598 
2599   for (int i = 0; i < input_shape.dims(); i++) {
2600     strides_vec(i) = 1;
2601     begin_mask[i] = false;
2602     if (size_vec(i) == -1) {
2603       end_mask[i] = true;
2604       end_vec(i) = 0;
2605       size_vec(i) = 0;
2606     } else {
2607       end_mask[i] = false;
2608       end_vec(i) = begin_vec(i) + size_vec(i);
2609       if (end_vec(i) > input_shape.dim_size(i) && input_shape.dim_size(i) > 0) {
2610         return errors::InvalidArgument("\"begin\" + \"size\" for dimension ", i,
2611                                        " in Slice is out of range");
2612       }
2613     }
2614   }
2615 
2616   auto bitset_to_int32 = [](const std::bitset<32>& bs) {
2617     return static_cast<int32_t>(bs.to_ulong());
2618   };
2619 
2620   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
2621       &begin_weights.GetTensor(), &end_tensor, strides_tensor, input_shape,
2622       bitset_to_int32(begin_mask), bitset_to_int32(end_mask),
2623       bitset_to_int32(ellipsis_mask), bitset_to_int32(new_axis_mask),
2624       bitset_to_int32(shrink_axis_mask), &processing_shape, &final_shape,
2625       &is_identity, &is_simple_slice, &slice_dim0, &begin, &end, &strides,
2626       &strided_slice_spec));
2627 
2628   VLOG(2) << "ConvertSlice: "
2629           << "\n input_shape: " << input_shape
2630           << "\n procesing_shape: " << processing_shape
2631           << "\n final_shape: " << final_shape
2632           << "\n  begin: " << DebugString(begin)
2633           << "\n  stride: " << DebugString(strides)
2634           << "\n  end: " << DebugString(end)
2635           << "\n is identity: " << is_identity
2636           << "\n is simple_slice: " << is_simple_slice
2637           << "\n slice dim0: " << slice_dim0 << " StridedSliceShapeSpec:"
2638           << "\n   begin_dense_mask: "
2639           << std::bitset<32>(strided_slice_spec.begin_dense_mask)
2640           << "\n   end_dense_mask: "
2641           << std::bitset<32>(strided_slice_spec.end_dense_mask)
2642           << "\n   shrink_dense_mask: "
2643           << std::bitset<32>(strided_slice_spec.shrink_axis_dense_mask);
2644 
2645   return ConvertStridedSliceHelper(params, inputs.at(0), input_shape, begin,
2646                                    strides, end, std::nullopt, std::nullopt,
2647                                    strided_slice_spec);
2648 }
2649 
2650 Status ConvertStridedSlice(OpConverterParams* params) {
2651   const auto& inputs = params->inputs;
2652   const auto& node_def = params->node_def;
2653 
2654   TF_RETURN_IF_ERROR(CheckInputsWeights(
2655       *params,
2656       {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
2657   TF_RETURN_IF_ERROR(AllowDataTypes(
2658       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
2659 
2660   int32 begin_mask, end_mask, ellipsis_mask, shrink_axis_mask, new_axis_mask;
2661   AttrSlice attrs(node_def);
2662   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "begin_mask", &begin_mask));
2663   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "end_mask", &end_mask));
2664   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ellipsis_mask", &ellipsis_mask));
2665   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "shrink_axis_mask", &shrink_axis_mask));
2666   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "new_axis_mask", &new_axis_mask));
2667 
2668   // New_axis_mask is not supported. TODO(tfeher): Support this by expanddims.
2669   if (new_axis_mask != 0) {
2670     return errors::Unimplemented(
2671         "new_axis_mask is not supported for StridedSlice");
2672   }
2673 
2674   // Shrinking axis on batch dimension is not allowed in implicit batch mode.
2675   if (params->use_implicit_batch && shrink_axis_mask & 1) {
2676     return errors::Unimplemented(
2677         "TensorRT does not allow modifications to the batch dimension");
2678   }
2679 
2680   // Convert TensorRT dimensions to TensorFlow shape. Implicit batch is added to
2681   // the TensorFlow shape, to be consistent with the weights and masks and to
2682   // support the use of TensorFlow slice op validator.
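  // For illustration (editor's note, not in the original source): in implicit
  // batch mode a TRT tensor with dims [3, 4] and batch_size 8 becomes the
  // TensorFlow shape [8, 3, 4]; in explicit batch mode the TRT dims already
  // contain the batch dimension and are used as-is.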
2683   PartialTensorShape input_shape;
2684   TF_RETURN_IF_ERROR(
2685       DimsAdapter(inputs.at(0).GetTrtDims())
2686           .PartialTensorShape(
2687               &input_shape, params->use_implicit_batch
2688                                 ? std::optional<int>(inputs.at(0).batch_size())
2689                                 : std::nullopt));
2690 
2691   const TRT_ShapedWeights& begin_weights = inputs.at(1).weights();
2692   const TRT_ShapedWeights& end_weights = inputs.at(2).weights();
2693   const TRT_ShapedWeights& stride_weights = inputs.at(3).weights();
2694   if (!AllLengthsEqual({begin_weights.ToVector<int>(),
2695                         end_weights.ToVector<int>(),
2696                         stride_weights.ToVector<int>()})) {
2697     return errors::InvalidArgument(
2698         "Length of begin, end, and stride must be equal");
2699   }
2700 
2701   PartialTensorShape processing_shape;
2702   PartialTensorShape final_shape;
2703   bool is_identity;
2704   bool is_simple_slice;
2705   bool slice_dim0;
2706   absl::InlinedVector<int64, 4> begin;
2707   absl::InlinedVector<int64, 4> end;
2708   absl::InlinedVector<int64, 4> strides;
2709   StridedSliceShapeSpec strided_slice_spec;
2710 
2711   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
2712       &begin_weights.GetTensor(), &end_weights.GetTensor(),
2713       stride_weights.GetTensor(), input_shape, begin_mask, end_mask,
2714       ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape,
2715       &final_shape, &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
2716       &strides, &strided_slice_spec));
2717 
2718   if (!params->validation_only) {
2719     VLOG(2) << "After ValidateStridedSliceOp:"
2720             << "\n input_shape: " << input_shape
2721             << "\n procesing_shape: " << processing_shape
2722             << "\n final_shape: " << final_shape
2723             << "\n  begin: " << DebugString(begin)
2724             << "\n  stride: " << DebugString(strides)
2725             << "\n  end: " << DebugString(end)
2726             << " is identity: " << is_identity
2727             << "\n is simple_slice: " << is_simple_slice
2728             << "\n slice dim0: " << slice_dim0 << " StridedSliceShapeSpec:"
2729             << "\n   begin_dense_mask: "
2730             << std::bitset<32>(strided_slice_spec.begin_dense_mask)
2731             << "\n   end_dense_mask: "
2732             << std::bitset<32>(strided_slice_spec.end_dense_mask)
2733             << "\n   shrink_dense_mask: "
2734             << std::bitset<32>(strided_slice_spec.shrink_axis_dense_mask);
2735   }
2736 
2737   // If the first bit of ellipsis_mask is set and fewer dimensions are
2738   // specified than the number of input dimensions, then the batch dimension
2739   // is not modified. Otherwise we must check whether the batch dimension is
2740   // modified.
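  // Example (editor's sketch, not in the original source): for a 4-D input and
  // a StridedSlice equivalent to x[..., 0:2], ellipsis_mask has bit 0 set and
  // only two dimensions are specified, so the batch dimension is untouched and
  // the check below is skipped.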
2741   if (params->use_implicit_batch &&
2742       !((ellipsis_mask & 1) &&
2743         begin_weights.Shape().NumDims() < input_shape.dims())) {
2744     // Check that batch dimension is unmodified. We need to use the expanded
2745     // begin/end/strides array since the original array may be incorrect when
2746     // (ellipsis_mask&1)==1.
2747     const bool begin_is_modified = !(begin_mask & 1) && (begin[0] != 0);
2748     const bool stride_is_modified = (strides[0] != 1);
2749     // If the end mask is not set, we can only verify that the batch dimension
2750     // is unmodified when the batch size is defined. When the batch size is
2751     // undefined (-1), we don't convert, to be safe.
2752     const bool batch_size_is_defined = (input_shape.dim_size(0) > 0);
2753     const bool end_is_modified =
2754         !(end_mask & 1) &&
2755         (!batch_size_is_defined || (end[0] != input_shape.dim_size(0)));
2756     if (begin_is_modified || stride_is_modified || end_is_modified) {
2757       return errors::Unimplemented(
2758           "TensorRT does not allow modifications to the batch dimension");
2759     }
2760   }
2761 
2762   // shrink_axis_mask requires a reshape after the slice.
2763   std::optional<nvinfer1::Dims> final_shape_dims = std::nullopt;
2764   if (shrink_axis_mask) {
2765     final_shape_dims.emplace();
2766     auto dims_adap =
2767         DimsAdapter::Create(final_shape, params->use_implicit_batch);
2768     TRT_ENSURE_OK(dims_adap);
2769     *final_shape_dims = dims_adap->AsTrtDims();
2770   }
2771 
2772   return ConvertStridedSliceHelper(params, inputs.at(0), input_shape, begin,
2773                                    strides, end, final_shape_dims, 0,
2774                                    strided_slice_spec);
2775 }
2776 
2777 Status ConvertConv2D(OpConverterParams* params) {
2778   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false);
2779 }
2780 
2781 Status ConvertConv2DDepthwise(OpConverterParams* params) {
2782   return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false);
2783 }
2784 
2785 Status ConvertConv2DBackpropInput(OpConverterParams* params) {
2786   return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true);
2787 }
2788 
2789 Status ConvertConv3DHelper(OpConverterParams* params, int group,
2790                            bool is_conv3d_backprop_input = false) {
2791   const int kNumDims = 5;
2792   const auto& inputs = params->inputs;
2793   const auto& node_def = params->node_def;
2794   TRT_TensorOrWeights backprop_output_size;
2795   ITensorProxyPtr tensor = nullptr;
2796   if (is_conv3d_backprop_input) {
2797     // In the case when Conv3dBackpropInput is used for conv3d_transpose, these
2798     // inputs correspond to: output size, filter, and input.
2799     TF_RETURN_IF_ERROR(CheckInputsWeights(
2800         *params,
2801         {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
2802     backprop_output_size = inputs.at(0);
2803     tensor = inputs.at(2).tensor();
2804   } else {
2805     TF_RETURN_IF_ERROR(
2806         CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
2807     tensor = inputs.at(0).tensor();
2808   }
2809   TF_RETURN_IF_ERROR(
2810       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2811   const TRT_ShapedWeights weights_drsck = inputs.at(1).weights();
2812   if (weights_drsck.Shape().NumDims() != kNumDims) {
2813     return errors::InvalidArgument("Conv3D expects kernel of dimension 5");
2814   }
2815 
2816   string data_format, padding_type;
2817   std::vector<int64_t> tf_dilations, tf_stride;
2818   AttrSlice attrs(node_def);
2819   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
2820   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding_type));
2821   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "dilations", &tf_dilations));
2822   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &tf_stride));
2823 
2824   const bool is_ndhwc = (data_format == "NDHWC");  // Or NCDHW 01234 -> 02341
2825   const int d_index = is_ndhwc ? 1 : 2;
2826   const int h_index = is_ndhwc ? 2 : 3;
2827   const int w_index = is_ndhwc ? 3 : 4;
2828   const int c_index = is_ndhwc ? 4 : 1;
2829   if (tf_dilations.size() != kNumDims) {
2830     return errors::InvalidArgument(
2831         "Convolution dilations field must specify 5 dimensions");
2832   }
2833   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
2834     return errors::Unimplemented(
2835         "Dilation rate must be 1 for batch and channel dimensions");
2836   }
2837 
2838   const nvinfer1::Dims3 dilation_dhw(
2839       tf_dilations[d_index], tf_dilations[h_index], tf_dilations[w_index]);
2840   if (is_conv3d_backprop_input &&
2841       (dilation_dhw.d[0] != 1 || dilation_dhw.d[1] != 1 ||
2842        dilation_dhw.d[2] != 1)) {
2843     return errors::Unimplemented(
2844         "Dilation with Conv3DBackpropInputV2 (conv3d_transpose) is not "
2845         "supported");
2846   }
2847 
2848   if (tf_stride.size() != kNumDims) {
2849     return errors::InvalidArgument(
2850         "Convolution strides field must specify 5 dimensions");
2851   }
2852   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
2853     return errors::Unimplemented(
2854         "Stride must be 1 for batch and channel dimensions");
2855   }
2856 
2857   const nvinfer1::Dims3 stride_dhw(tf_stride[d_index], tf_stride[h_index],
2858                                    tf_stride[w_index]);
2859   const auto tensor_dim = tensor->getDimensions();
2860 
2861   // Asymmetric padding on Deconv not supported for now
2862   if (is_conv3d_backprop_input && padding_type == "SAME") {
2863     StatusOr<TRT_ShapedWeights> weights =
2864         params->weight_store->GetTempWeights(weights_drsck);
2865     TRT_ENSURE_OK(weights);
2866     nvinfer1::Dims3 effective_kernel_size(
2867         weights->Shape().dim(0) +
2868             (weights->Shape().dim(0) - 1) * (dilation_dhw.d[0] - 1),  // D
2869         weights->Shape().dim(1) +
2870             (weights->Shape().dim(1) - 1) * (dilation_dhw.d[1] - 1),  // R
2871         weights->Shape().dim(2) +
2872             (weights->Shape().dim(2) - 1) * (dilation_dhw.d[2] - 1)  // S
2873     );
2874 
2875     const auto output_size_weights =
2876         backprop_output_size.weights().GetPointer<int>();
2877     const std::vector<int64_t> input_dims = {output_size_weights[d_index],
2878                                              output_size_weights[h_index],
2879                                              output_size_weights[w_index]};
2880 
2881     const std::vector<std::pair<int, int>> padding =
2882         CreateSamePadding(stride_dhw, effective_kernel_size, input_dims);
2883 
2884     if (padding[0].first != padding[0].second ||
2885         padding[1].first != padding[1].second ||
2886         padding[2].first != padding[2].second) {
2887       return errors::Unimplemented(
2888           "Asymmetric padding with Conv3DBackpropInputV2 (conv3d_transpose) is "
2889           "not supported");
2890     }
2891   }
2892 
2893   // Channel dim must be static for Conv3D since we use that value for
2894   // num_groups at build time.
2895   // TODO: Allow conversion if kImplicitBatchModeCompatible||kOptimal is used.
2896   int implicit_batch_offset = params->use_implicit_batch ? -1 : 0;
2897   if (tensor->getDimensions().d[c_index + implicit_batch_offset] == -1) {
2898     return errors::InvalidArgument("Channel dimension must be static");
2899   }
2900 
2901   // Finished validation checks
2902   if (params->validation_only) return Status::OK();
2903 
2904   // Transpose to NCDHW (NCDHW is required for IConvLayer).
2905   const bool need_transpose = is_ndhwc;
2906   if (need_transpose) {
2907     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2908         tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW"));
2909   }
2910 
2911   // group == 0 signifies that this is a depthwise convolution, so set
2912   // num_groups to size of input's channel dim. For a non-depthwise conv,
2913   // num_groups will be 1.
2914   const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
2915 
2916   // For conv, TF weights are DRSCK, and TRT expects KCDRS.
2917   // For backprop, TF weights are DRSKC, and TRT expects KCDRS.
2918   // Therefore, this reorder will work for both cases.
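  // Illustrative example (editor's note, not in the original source): a DRSCK
  // kernel of shape [3, 3, 3, 16, 32] is reordered to KCDRS [32, 16, 3, 3, 3]
  // (assuming num_groups == 1), the layout expected by addConvolutionNd and
  // addDeconvolutionNd below.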
2919   StatusOr<TRT_ShapedWeights> weights =
2920       params->weight_store->GetTempWeights(weights_drsck);
2921   TRT_ENSURE_OK(weights);
2922   ReorderDRSCKToKCDRS(weights_drsck, &*weights, num_groups);
2923   TRT_ShapedWeights biases(weights->TrtDType());
2924   const int output_axis = is_conv3d_backprop_input ? 1 : 0;
2925   const int noutput = weights->Shape().dim(output_axis) * num_groups;
2926   nvinfer1::Dims3 kernel_size_drs(weights->Shape().dim(2),  // D
2927                                   weights->Shape().dim(3),  // R
2928                                   weights->Shape().dim(4)   // S
2929   );
2930 
2931   // Add convolution.
2932   nvinfer1::ILayer* conv_layer = nullptr;
2933   if (is_conv3d_backprop_input) {
2934     nvinfer1::IDeconvolutionLayer* layer =
2935         params->converter->network()->addDeconvolutionNd(
2936             *tensor->trt_tensor(), noutput, kernel_size_drs,
2937             weights->GetTrtWeights(), biases.GetTrtWeights());
2938     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2939     layer->setStrideNd(stride_dhw);  // change to nd set stride
2940 
2941     if (padding_type == "SAME") {
2942       VLOG(2) << "Using SAME padding";
2943       // SAME_UPPER means that post padding is preferred.
2944       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
2945     }
2946 
2947     layer->setNbGroups(num_groups);
2948     conv_layer = layer;
2949   } else {
2950     nvinfer1::IConvolutionLayer* layer =
2951         params->converter->network()->addConvolutionNd(
2952             *tensor->trt_tensor(), noutput, kernel_size_drs,
2953             weights->GetTrtWeights(), biases.GetTrtWeights());
2954     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
2955     layer->setStrideNd(stride_dhw);
2956 
2957     if (padding_type == "SAME") {
2958       VLOG(2) << "Using SAME padding";
2959       layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
2960     }
2961 
2962     layer->setNbGroups(num_groups);
2963     layer->setDilationNd(dilation_dhw);
2964     conv_layer = layer;
2965   }
2966   params->converter->SetLayerName(conv_layer, node_def, "conv");
2967   ITensorProxyPtr output_tensor = conv_layer->getOutput(0);
2968 
2969   // Restore transpose.
2970   if (need_transpose) {
2971     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
2972         output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC"));
2973   }
2974   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
2975   return Status::OK();
2976 }
2977 
2978 Status ConvertConv3D(OpConverterParams* params) {
2979   return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/false);
2980 }
2981 
2982 Status ConvertConv3DBackpropInputV2(OpConverterParams* params) {
2983   return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/true);
2984 }
2985 
2986 Status ConvertPool3D(OpConverterParams* params) {
2987   const int kNumDims = 5;
2988   const auto& inputs = params->inputs;
2989   const auto& node_def = params->node_def;
2990   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
2991   TF_RETURN_IF_ERROR(
2992       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
2993   nvinfer1::PoolingType type;
2994   if (node_def.op() == "MaxPool3D") {
2995     type = nvinfer1::PoolingType::kMAX;
2996   } else if (node_def.op() == "AvgPool3D") {
2997     type = nvinfer1::PoolingType::kAVERAGE;
2998   } else {
2999     return errors::Unimplemented("Unsupported pooling type: ", node_def.op());
3000   }
3001 
3002   string data_format, padding_type;
3003   std::vector<int64_t> tf_stride, tf_kernel;
3004   AttrSlice attrs(node_def);
3005   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding_type));
3006   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
3007   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &tf_stride));
3008   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &tf_kernel));
3009 
3010   if ((padding_type != "SAME") && (padding_type != "VALID")) {
3011     return errors::Unimplemented("Unsupported padding type: ", padding_type);
3012   }
3013 
3014   const bool is_ndhwc = (data_format == "NDHWC");
3015   const int c_index = is_ndhwc ? 4 : 1;
3016   const int d_index = is_ndhwc ? 1 : 2;
3017   const int h_index = is_ndhwc ? 2 : 3;
3018   const int w_index = is_ndhwc ? 3 : 4;
3019 
3020   if (tf_stride.size() != kNumDims) {
3021     return errors::InvalidArgument(
3022         "Pooling strides field must specify 5 dimensions");
3023   }
3024   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
3025     return errors::Unimplemented(
3026         "stride must be 1 for batch and channel dimensions");
3027   }
3028 
3029   if (tf_kernel.size() != kNumDims) {
3030     return errors::InvalidArgument(
3031         "Pooling ksize field must specify 5 dimensions");
3032   }
3033   if (tf_kernel[0] != 1 || tf_kernel[c_index] != 1) {
3034     return errors::Unimplemented(
3035         "ksize must be 1 for batch and channel dimensions");
3036   }
3037   if (params->validation_only) return Status::OK();
3038 
3039   ITensorProxyPtr tensor = inputs.at(0).tensor();
3040   if (data_format == "NDHWC") {
3041     // NDHWC => NCDHW
3042     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3043         tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW"));
3044   }
3045 
3046   const nvinfer1::Dims3 stride(tf_stride[d_index], tf_stride[h_index],
3047                                tf_stride[w_index]);
3048   const nvinfer1::Dims3 ksize(tf_kernel[d_index], tf_kernel[h_index],
3049                               tf_kernel[w_index]);
3050 
3051   nvinfer1::IPoolingLayer* layer = params->converter->network()->addPoolingNd(
3052       *tensor->trt_tensor(), type, ksize);
3053   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3054 
3055   layer->setStrideNd(stride);
3056   // VALID padding is the default TRT behavior.
3057   if (padding_type == "SAME") {
3058     // SAME_UPPER means that post padding is preferred.
3059     layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3060   }
3061   params->converter->SetLayerName(layer, node_def, "pooling");
3062 
3063   ITensorProxyPtr output_tensor = layer->getOutput(0);
3064   if (data_format == "NDHWC") {
3065     // NCDHW => NDHWC
3066     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3067         output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC"));
3068   }
3069 
3070   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3071   return Status::OK();
3072 }
3073 
3074 Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) {
3075   const auto& inputs = params->inputs;
3076   const auto& node_def = params->node_def;
3077 
3078   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false},
3079                                                   {"filter", true},
3080                                                   {"bias", true},
3081                                                   {"side_input", true},
3082                                                   {"conv_input_scale", true},
3083                                                   {"side_input_scale", true}}));
3084   ITensorProxyPtr tensor = inputs.at(0).tensor();
3085   TF_RETURN_IF_ERROR(
3086       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3087   TRT_ShapedWeights weights = inputs.at(1).weights();
3088   if (weights.Shape().NumDims() != 4) {
3089     return errors::InvalidArgument(
3090         "FusedConv2DBiasActivation expects kernel of dimension 4");
3091   }
3092 
3093   string data_format, filter_format, activation_mode, padding_type;
3094   std::vector<int64_t> tf_dilations, tf_stride;
3095   AttrSlice attrs(node_def);
3096   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
3097   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "filter_format", &filter_format));
3098   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "activation_mode", &activation_mode));
3099   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding_type));
3100   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "dilations", &tf_dilations));
3101   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &tf_stride));
3102 
3103   if (data_format != "NHWC" && data_format != "NCHW") {
3104     return errors::InvalidArgument("Unsupported data_format:", data_format);
3105   }
3106   int c_index = (data_format == "NHWC") ? 3 : 1;
3107   int h_index = (data_format == "NHWC") ? 1 : 2;
3108   int w_index = (data_format == "NHWC") ? 2 : 3;
3109 
3110   if (tf_dilations.size() != 4) {
3111     return errors::InvalidArgument(
3112         "Convolution dilations field must specify 4 dimensions");
3113   }
3114   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
3115     return errors::Unimplemented(
3116         "Dilation rate must be 1 for batch and channel dimensions");
3117   }
3118   const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
3119 
3120   if (tf_stride.size() != 4) {
3121     return errors::InvalidArgument(
3122         "Convolution strides field must specify 4 dimensions");
3123   }
3124   if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
3125     return errors::Unimplemented(
3126         "Stride must be 1 for batch and channel dimensions");
3127   }
3128   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
3129   auto op_pair = ActivationTypeMap()->find(activation_mode);
3130   if (op_pair == ActivationTypeMap()->end() && activation_mode != "None") {
3131     return errors::Unimplemented("Activation mode not supported: ",
3132                                  activation_mode);
3133   }
3134 
3135   if (filter_format != "HWIO" && filter_format != "OIHW") {
3136     return errors::InvalidArgument("Unsupported filter_format:", filter_format);
3137   }
3138   // Check that there is no side_input and that conv_input_scale is 1.
3139   TRT_ShapedWeights side_input = inputs.at(3).weights();
3140   if (side_input.count() != 0) {
3141     return errors::InvalidArgument(
3142         "FusedConv2DBiasActivation doesn't yet support side_input");
3143   }
3144   TRT_ShapedWeights conv_input_scale = inputs.at(4).weights();
3145   if (conv_input_scale.count() != 1 ||
3146       conv_input_scale.TrtDType() != nvinfer1::DataType::kFLOAT ||
3147       conv_input_scale.GetSpan<float>()[0] != 1.0) {
3148     return errors::InvalidArgument(
3149         "FusedConv2DBiasActivation doesn't yet support conv_input_scale");
3150   }
3151   if (params->validation_only) return Status::OK();
3152 
3153   // Transpose to NCHW (NCHW is required for IConvLayer).
3154   const bool need_transpose = (data_format == "NHWC");
3155   if (need_transpose) {
3156     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3157         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
3158   }
3159 
3160   nvinfer1::DimsHW kernel_size;
3161   if (filter_format == "OIHW") {
3162     kernel_size.h() = weights.Shape().dim(2);
3163     kernel_size.w() = weights.Shape().dim(3);
3164   } else {
3165     // HWIO.
3166     DCHECK_EQ(filter_format, "HWIO");
3167     kernel_size.h() = weights.Shape().dim(0);
3168     kernel_size.w() = weights.Shape().dim(1);
3169   }
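  // For illustration (editor's note, not in the original source): an HWIO
  // filter of shape [5, 5, 16, 32] and an OIHW filter of shape [32, 16, 5, 5]
  // both yield kernel_size = (h=5, w=5) here.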
3170 
3171   // Add convolution.
3172   TRT_ShapedWeights biases = inputs.at(2).weights();
3173   nvinfer1::IConvolutionLayer* conv_layer = nullptr;
3174   if (filter_format == "OIHW") {
3175     // Weights are already in the right order.
3176     conv_layer = params->converter->network()->addConvolution(
3177         *tensor->trt_tensor(), weights.Shape().dim(0), kernel_size,
3178         weights.GetTrtWeights(), biases.GetTrtWeights());
3179   } else {
3180     // For conv, TF weights are RSCK, and TRT expects KCRS.
3181     TRT_ENSURE(filter_format == "HWIO");
3182     StatusOr<TRT_ShapedWeights> weights_kcrs =
3183         params->weight_store->GetTempWeights(weights);
3184     TRT_ENSURE_OK(weights_kcrs);
3185     ReorderRSCKToKCRS(weights, &*weights_kcrs, 1);
3186     conv_layer = params->converter->network()->addConvolution(
3187         *tensor->trt_tensor(), weights.Shape().dim(3), kernel_size,
3188         weights_kcrs->GetTrtWeights(), biases.GetTrtWeights());
3189   }
3190   TFTRT_RETURN_ERROR_IF_NULLPTR(conv_layer, node_def.name());
3191   conv_layer->setStride(stride);
3192   if (padding_type == "SAME") {
3193     conv_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3194   }
3195   params->converter->SetLayerName(conv_layer, node_def, "conv");
3196   conv_layer->setNbGroups(1);
3197   conv_layer->setDilation(dilation);
3198   ITensorProxyPtr output_tensor = conv_layer->getOutput(0);
3199 
3200   // Add activation if there is one.
3201   if (op_pair != ActivationTypeMap()->end()) {
3202     nvinfer1::IActivationLayer* activation_layer =
3203         params->converter->network()->addActivation(
3204             *output_tensor->trt_tensor(), op_pair->second);
3205     TFTRT_RETURN_ERROR_IF_NULLPTR(activation_layer, node_def.name());
3206     params->converter->SetLayerName(activation_layer, node_def, "activation");
3207     output_tensor = activation_layer->getOutput(0);
3208   }
3209   // Restore transpose.
3210   if (need_transpose) {
3211     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3212         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
3213   }
3214   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3215   return Status::OK();
3216 }
3217 
3218 Status ConvertPool(OpConverterParams* params) {
3219   const auto& inputs = params->inputs;
3220   const auto& node_def = params->node_def;
3221   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
3222   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
3223                                    DataType::DT_INT8};
3224   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
3225   nvinfer1::PoolingType type;
3226   if (node_def.op() == "MaxPool") {
3227     type = nvinfer1::PoolingType::kMAX;
3228   } else if (node_def.op() == "AvgPool") {
3229     type = nvinfer1::PoolingType::kAVERAGE;
3230   } else {
3231     return errors::Unimplemented("Unsupported pooling type: ", node_def.op());
3232   }
3233 
3234   string data_format, padding_type;
3235   std::vector<int64_t> tf_stride, tf_kernel;
3236   AttrSlice attrs(node_def);
3237   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
3238   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding_type));
3239   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &tf_stride));
3240   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &tf_kernel));
3241 
3242   if ((padding_type != "SAME") && (padding_type != "VALID")) {
3243     return errors::Unimplemented("Unsupported padding type: ", padding_type);
3244   }
3245   if (params->validation_only) return Status::OK();
3246 
3247   ITensorProxyPtr tensor = inputs.at(0).tensor();
3248   int h_index = 2;
3249   int w_index = 3;
3250   if (data_format == "NHWC") {
3251     h_index = 1;
3252     w_index = 2;
3253     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3254         tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW"));
3255   }
3256 
3257   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
3258   const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
3259 
3260   nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling(
3261       *tensor->trt_tensor(), type, ksize);
3262   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3263 
3264   layer->setStride(stride);
3265   // VALID padding is the default TRT behavior.
3266   if (padding_type == "SAME") {
3267     // SAME_UPPER means that post padding is preferred.
3268     layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
3269   }
3270   params->converter->SetLayerName(layer, node_def, "pooling");
3271   ITensorProxyPtr output_tensor = layer->getOutput(0);
3272 
3273   if (data_format == "NHWC") {
3274     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3275         output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC"));
3276   }
3277   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3278   return Status::OK();
3279 }
3280 
3281 Status ConvertClipByValue(OpConverterParams* params) {
3282   const auto& inputs = params->inputs;
3283   const auto& node_def = params->node_def;
3284   // TODO(tmorris): We can also allow the case where min and max are tensors by
3285   // using elementwise min and max layers.
3286   TF_RETURN_IF_ERROR(CheckInputsWeights(
3287       *params,
3288       {{"t", false}, {"clip_value_min", true}, {"clip_value_max", true}}));
3289   TF_RETURN_IF_ERROR(
3290       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3291   if (params->validation_only) return Status::OK();
3292 
3293   DataType dtype;
3294   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "T", &dtype));
3295 
3296   float clip_value_min = 0.0f;
3297   float clip_value_max = 0.0f;
3298   // TODO(tmorris): Add a templated helper function to get scalar weights of
3299   // InType casted to OutType.
3300   if (dtype == DataType::DT_FLOAT) {
3301     clip_value_min = inputs.at(1).weights().GetSpan<float>()[0];
3302     clip_value_max = inputs.at(2).weights().GetSpan<float>()[0];
3303   } else if (dtype == DataType::DT_HALF) {
3304     clip_value_min =
3305         static_cast<float>(inputs.at(1).weights().GetSpan<Eigen::half>()[0]);
3306     clip_value_max =
3307         static_cast<float>(inputs.at(2).weights().GetSpan<Eigen::half>()[0]);
3308   }
3309 
3310   nvinfer1::IActivationLayer* layer =
3311       params->converter->network()->addActivation(
3312           *inputs.at(0).tensor()->trt_tensor(),
3313           nvinfer1::ActivationType::kCLIP);
3314   layer->setAlpha(clip_value_min);
3315   layer->setBeta(clip_value_max);
3316   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3317   params->converter->SetLayerName(layer, node_def, "activation");
3318   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
3319   return Status::OK();
3320 }
3321 
3322 Status ConvertBiasAdd(OpConverterParams* params) {
3323   const auto& inputs = params->inputs;
3324   const auto& node_def = params->node_def;
3325   TFTRT_CHECK_INPUT_SIZE(inputs.size(), 2, node_def);
3326 
3327   if (inputs[0].is_weights() && inputs[1].is_weights()) {
3328     // TODO(lsugy): don't assume that if all inputs are weights, grappler
3329     // should fold them, because variables are weights.
3330     return errors::InvalidArgument(
3331         "All inputs are weights, but Grappler is expected to fold them.");
3332   }
3333 
3334   TF_RETURN_IF_ERROR(
3335       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
3336 
3337   string data_format;
3338   TF_RETURN_IF_ERROR(
3339       GetNodeAttr(AttrSlice(node_def), "data_format", &data_format));
3340 
3341   nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims();
3342   nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims();
3343   // The bias input arg is a 1-D tensor with length C. If the input is NCHW,
3344   // then we need to unsqueeze the bias such that its shape is [1, C, 1, 1].
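  // Concrete example (editor's sketch, not in the original source): for an
  // NCHW input of shape [N, 64, 28, 28], a length-64 bias becomes [64, 1, 1]
  // in implicit batch mode (no batch dim in the TRT shape) and [1, 64, 1, 1]
  // in explicit batch mode, so the elementwise SUM below broadcasts over H
  // and W.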
3345   if (data_format == "NCHW") {
3346     if (params->use_implicit_batch) {
3347       // The batch dim is not included in implicit batch mode, so the shape of
3348       // the bias tensor is [C, 1, 1].
3349       bias_shape.nbDims = input_shape.nbDims;
3350       std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1);
3351     } else {
3352       // In explicit batch mode we create a tensor with shape [1, C, 1, 1].
3353       std::vector<int> bias_shape_vec(bias_shape.d,
3354                                       bias_shape.d + bias_shape.nbDims);
3355       // Insert a leading 1 for the batch dim.
3356       bias_shape_vec.insert(bias_shape_vec.begin(), 1);
3357       // Pad with trailing 1s to match the rank of input_shape.
3358       bias_shape_vec.insert(bias_shape_vec.end(),
3359                             input_shape.nbDims - bias_shape_vec.size(), 1);
3360       DimsAdapter(bias_shape_vec).TrtDims(&bias_shape);
3361     }
3362   } else {
3363     // Next, broadcast the bias across the input.
3364     TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1),
3365                                             /*check_feasibility=*/true,
3366                                             params->use_implicit_batch,
3367                                             &input_shape, &bias_shape));
3368   }
3369 
3370   // Convert input to a TRT tensor
3371   ITensorProxyPtr input_tensor{nullptr};
3372   TF_RETURN_IF_ERROR(PrepareTensorForShape(
3373       params->converter, inputs.at(0), DimsAdapter(input_shape),
3374       params->validation_only, &input_tensor, node_def,
3375       /*op_instance=*/0));
3376 
3377   // Finally, reshape bias. Since the bias is usually a constant, this will
3378   // normally happen at conversion-time.
3379   ITensorProxyPtr bias_tensor{nullptr};
3380   TF_RETURN_IF_ERROR(PrepareTensorForShape(
3381       params->converter, inputs.at(1), DimsAdapter(bias_shape),
3382       params->validation_only, &bias_tensor, node_def,
3383       /*op_instance=*/1));
3384   VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape);
3385 
3386   if (params->validation_only) return Status::OK();
3387 
3388   nvinfer1::IElementWiseLayer* layer =
3389       params->converter->network()->addElementWise(
3390           *input_tensor->trt_tensor(), *bias_tensor->trt_tensor(),
3391           nvinfer1::ElementWiseOperation::kSUM);
3392   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3393   params->converter->SetLayerName(layer, node_def, "sum");
3394   ITensorProxyPtr output_tensor = layer->getOutput(0);
3395 
3396   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3397   return Status::OK();
3398 }
3399 
3400 template <typename Input>
3401 inline bool IsIntegerInInt32Bounds(const Input& inp) {
3402   static_assert(std::is_integral<Input>::value,
3403                 "This function is only implemented for integral types.");
3404   // If Input is always within the range of int32, return true.
3405   if (sizeof(Input) < sizeof(int32) || std::is_same<Input, int32>::value) {
3406     return true;
3407   }
3408   // Otherwise, we need to check the value of the input. If the input is
3409   // unsigned, we only check the upper bound.
3410   if (!std::numeric_limits<Input>::is_signed) {
3411     return inp <= static_cast<Input>(std::numeric_limits<int32>::max());
3412   }
3413   // We can safely cast lowest() here since we now know that Input is signed and
3414   // sizeof(Input) >= sizeof(int32)
3415   return (inp >= static_cast<Input>(std::numeric_limits<int32>::lowest()) &&
3416           inp <= static_cast<Input>(std::numeric_limits<int32>::max()));
3417 }
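// Usage sketch (editor's addition, not in the original source):
//   IsIntegerInInt32Bounds(int64_t{1} << 40)  -> false (exceeds int32 max)
//   IsIntegerInInt32Bounds(uint64_t{42})      -> true
//   IsIntegerInInt32Bounds(int16_t{-5})       -> true (int16 always fits)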
3418 
3419 template <DataType dtype>
3420 Status CopyToTrtInt32Array(const Tensor& tensor, int32* dst) {
3421   typedef typename EnumToDataType<dtype>::Type CType;
3422   const CType* src = tensor.flat<CType>().data();
3423   for (int i = 0; i < tensor.NumElements(); ++i) {
3424     // This becomes a no-op if CType is within bounds of int32
3425     if (!IsIntegerInInt32Bounds(src[i])) {
3426       return errors::InvalidArgument("Value at index ", i,
3427                                      " is outside the range of int32");
3428     }
3429     dst[i] = static_cast<int32>(src[i]);
3430   }
3431   return Status::OK();
3432 }
3433 
3434 Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
3435                             TRT_ShapedWeights* weights) {
3436   const DataType dtype = tensor.dtype();
3437 
3438   // We always convert the integer constants to INT32.
3439   //
3440   // TODO(aaroey): FP16 will remain in half format and is not converted to
3441   // FP32, but the converter currently uses all float weights as FP32. Fix
3442   // this.
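  // For example (editor's note, not in the original source): a DT_INT64
  // constant of shape [2, 3] whose values all fit in int32 becomes INT32 TRT
  // weights with the same dims; any value outside the int32 range makes
  // CopyToTrtInt32Array below return an InvalidArgument error.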
3443   DataType converted_dtype = DataTypeIsInteger(dtype) ? DT_INT32 : dtype;
3444 
3445   // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
3446   nvinfer1::DataType trt_dtype;
3447   TF_RETURN_IF_ERROR(TfTypeToTrtType(converted_dtype, &trt_dtype));
3448 
3449   if (tensor.NumElements() == 0) {
3450     // Return empty weights.
3451     *weights = TRT_ShapedWeights(trt_dtype);
3452     return Status::OK();
3453   }
3454 
3455   StatusOr<DimsAdapter> weight_dims = DimsAdapter::Create(tensor.shape());
3456   TRT_ENSURE_OK(weight_dims);
3457 
3458   auto tmp = weight_store->GetTempWeights(trt_dtype, weight_dims->AsTrtDims());
3459   TRT_ENSURE_OK(tmp);
3460   *weights = std::move(tmp).value();
3461 
3462   // Copy the tensor directly if the tensor does not require cast to the
3463   // supported type.
3464   if (converted_dtype == dtype) {
3465     std::copy_n(tensor.tensor_data().data(), tensor.TotalBytes(),
3466                 weights->GetPointer<int8>());
3467     return Status::OK();
3468   }
3469 
3470   Status status = Status::OK();
3471   // Copy tensor elements after casting them to the converted DataType.
3472   int32* dst = weights->GetPointer<int32>();
3473   switch (dtype) {
3474     case DT_INT8:
3475       status = CopyToTrtInt32Array<DT_INT8>(tensor, dst);
3476       break;
3477     case DT_UINT8:
3478       status = CopyToTrtInt32Array<DT_UINT8>(tensor, dst);
3479       break;
3480     case DT_INT16:
3481       status = CopyToTrtInt32Array<DT_INT16>(tensor, dst);
3482       break;
3483     case DT_UINT16:
3484       status = CopyToTrtInt32Array<DT_UINT16>(tensor, dst);
3485       break;
3486     case DT_UINT32:
3487       status = CopyToTrtInt32Array<DT_UINT32>(tensor, dst);
3488       break;
3489     case DT_INT64:
3490       status = CopyToTrtInt32Array<DT_INT64>(tensor, dst);
3491       break;
3492     case DT_UINT64:
3493       status = CopyToTrtInt32Array<DT_UINT64>(tensor, dst);
3494       break;
3495     default:
3496       return errors::Internal("Unexpected DataType: ", DataTypeString(dtype));
3497   }
3498   return status;
3499 }
3500 
3501 // Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter:
3502 // it always ignores the params->validation_only parameter and adds the
3503 // converted weights to params->outputs. We do this because TrtNodeValidator
3504 // needs the weights as input to other nodes, and uses them to determine
3505 // whether those nodes are supported by TRT.
3506 Status ConvertConst(OpConverterParams* params) {
3507   const auto& inputs = params->inputs;
3508   const auto& node_def = params->node_def;
3509   if (!inputs.empty()) {
3510     return errors::InvalidArgument(
3511         "Constant node is expected to have empty input list");
3512   }
3513 
3514   // Create shaped weights as output
3515   const auto& tensor_proto = node_def.attr().at("value").tensor();
3516   Tensor tensor;
3517   if (!tensor.FromProto(tensor_proto)) {
3518     return errors::Internal("Cannot parse weight tensor proto: ",
3519                             node_def.name());
3520   }
3521 
3522   DataType dtype;
3523   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "dtype", &dtype));
3524 
3525   if (dtype != tensor.dtype()) {
3526     return errors::InvalidArgument("DataType mismatch between attr (",
3527                                    DataTypeString(dtype), ") and tensor (",
3528                                    DataTypeString(tensor.dtype()), ")");
3529   }
3530 
3531   TRT_ShapedWeights weights;
3532   TF_RETURN_IF_ERROR(
3533       TfTensorToTrtWeights(tensor, params->weight_store, &weights));
3534 
3535   if (params->outputs != nullptr) {
3536     params->outputs->push_back(TRT_TensorOrWeights(weights));
3537   }
3538   return Status::OK();
3539 }
3540 
3541 Status ConvertIdentity(OpConverterParams* params) {
3542   // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT
3543   // 5.0; once we know that it does, it would be nice to use that layer
3544   // instead.
3545   if (params->validation_only) return Status::OK();
3546 
3547   for (int i = 0; i < params->inputs.size(); i++) {
3548     params->outputs->push_back(params->inputs.at(i));
3549   }
3550   return Status::OK();
3551 }
3552 
3553 Status ConvertSquare(OpConverterParams* params) {
3554   const auto& inputs = params->inputs;
3555   const auto& node_def = params->node_def;
3556   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
3557   TF_RETURN_IF_ERROR(AllowDataTypes(
3558       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
3559   if (params->validation_only) return Status::OK();
3560 
3561   // Constant 2 with same rank as input
3562   ITensorProxyPtr const2_tensor = nullptr;
3563   TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
3564       params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor));
3565 
3566   // ElementWise Pow Operation
3567   nvinfer1::IElementWiseLayer* layer =
3568       params->converter->network()->addElementWise(
3569           *inputs.at(0).tensor()->trt_tensor(), *const2_tensor->trt_tensor(),
3570           nvinfer1::ElementWiseOperation::kPOW);
3571   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3572   params->converter->SetLayerName(layer, node_def);
3573   ITensorProxyPtr output_tensor = layer->getOutput(0);
3574 
3575   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3576   return Status::OK();
3577 }
3578 
3579 Status ConvertReduce(OpConverterParams* params) {
3580   const auto& inputs = params->inputs;
3581   const auto& node_def = params->node_def;
3582   TF_RETURN_IF_ERROR(
3583       CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
3584   TF_RETURN_IF_ERROR(AllowDataTypes(
3585       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
3586 
3587   ITensorProxyPtr tensor = inputs.at(0).tensor();
3588   auto tf_axes_list = inputs.at(1).weights().GetSpan<int>();
3589 
3590   DataType idx_dtype{DataType::DT_INT32};
3591   bool keep_dims{false};
3592   AttrSlice attrs(node_def);
3593   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &idx_dtype));
3594   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "keep_dims", &keep_dims));
3595 
3596   // Only expect to handle INT32 as attributes for now
3597   if (idx_dtype != DataType::DT_INT32) {
3598     return errors::Unimplemented("Tidx supports only DT_INT32");
3599   }
3600 
3601   int axes = 0;
3602   if (tf_axes_list.size() == 0) {
3603     return errors::InvalidArgument(
3604         "TRT cannot support reduce on all (batch) dimensions");
3605   }
3606   for (int i = 0; i < tf_axes_list.size(); i++) {
3607     int trt_axis;
3608     TF_RETURN_IF_ERROR(
3609         ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims,
3610                     node_def.name(), params->use_implicit_batch, &trt_axis));
3611     axes |= (1 << trt_axis);
3612   }
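  // Worked example (explicit batch, hypothetical rank): reducing a rank-4
  // tensor over TF axes {1, 2} yields trt_axis values 1 and 2, so the bitmask
  // becomes axes = (1 << 1) | (1 << 2) = 6.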
3613 
3614   nvinfer1::ReduceOperation reduce_operation;
3615   if (node_def.op() == "Sum") {
3616     reduce_operation = nvinfer1::ReduceOperation::kSUM;
3617   } else if (node_def.op() == "Prod") {
3618     reduce_operation = nvinfer1::ReduceOperation::kPROD;
3619   } else if (node_def.op() == "Max") {
3620     reduce_operation = nvinfer1::ReduceOperation::kMAX;
3621   } else if (node_def.op() == "Min") {
3622     reduce_operation = nvinfer1::ReduceOperation::kMIN;
3623   } else if (node_def.op() == "Mean") {
3624     reduce_operation = nvinfer1::ReduceOperation::kAVG;
3625   } else {
3626     return errors::Unimplemented("Op not supported ", node_def.op());
3627   }
3628   if (params->validation_only) return Status::OK();
3629 
3630   nvinfer1::ILayer* layer = params->converter->network()->addReduce(
3631       *tensor->trt_tensor(), reduce_operation, axes, keep_dims);
3632   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3633   params->converter->SetLayerName(layer, node_def);
3634 
3635   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
3636   return Status::OK();
3637 }
3638 
3639 // TensorRT does not support the Pack op natively. Therefore, Pack op is
3640 // converted by first expanding input tensors by adding a new dimension of size
3641 // one at the specified axis and then concatenating the tensors at the same
3642 // axis.
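// Worked example (hypothetical shapes): packing three [2, 3] tensors at axis=1
// first expands each input to [2, 1, 3] and then concatenates them along that
// axis, producing a [2, 3, 3] output, which matches tf.stack semantics.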
3643 Status ConvertPack(OpConverterParams* params) {
3644   const auto& inputs = params->inputs;
3645   const auto& node_def = params->node_def;
3646 
3647   int num_inputs{0};
3648   int64_t tf_axis{0};
3649   AttrSlice attrs(node_def);
3650   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &num_inputs));
3651   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &tf_axis));
3652 
3653   if (num_inputs != inputs.size()) {
3654     return errors::InvalidArgument(
3655         "Number of inputs for Pack is inconsistent with N attribute");
3656   }
3657 
3658   // In implicit batch mode we do not allow weight input. An input tensor with
3659   // dims NCHW is represented with dims CHW during conversion time, and N is
3660   // defined only during runtime. A weight is represented with dims NCHW. We
3661   // cannot be sure that the runtime N will agree with the conversion time N,
3662   // therefore we do not convert the pack op if it has both tensor and weight
3663   // inputs. This restriction does not apply in explicit batch mode, where the
3664   // input tensors are also represented with full dims that include the
3665   // batch size.
3666   TrtInputArg expected_arg =
3667       params->use_implicit_batch ? TrtInputArg::kTensor : TrtInputArg::kBoth;
3668 
3669   std::vector<std::pair<string, TrtInputArg>> inputs_is_weight;
3670   inputs_is_weight.reserve(num_inputs);
3671   for (int i = 0; i < num_inputs; ++i) {
3672     inputs_is_weight.push_back({StrCat("values_", i), expected_arg});
3673   }
3674   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_is_weight));
3675 
3676   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
3677                                    DataType::DT_INT32};
3678   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
3679   if (num_inputs > 1) {
3680     // Verify that inputs are compatible for concatenation after the expansion.
3681     TF_RETURN_IF_ERROR(
3682         VerifyShapesMatch(inputs, /*masked_dim=*/-1, node_def.name()));
3683   }
3684 
3685   // Find the dimensions of the inputs. In general the inputs can have dynamic
3686   // shape; in that case we have to use DynamicExpandDims to calculate the
3687   // expanded dimensions. To avoid that, we try to find a weight input, which is
3688   // guaranteed to have a known static shape.
3689   int idx = 0;
3690   for (int i = 1; i < inputs.size(); i++) {
3691     if (HasStaticShape(inputs.at(i).GetTrtDims())) {
3692       idx = i;
3693     }
3694   }
3695   DimsAdapter dims(inputs.at(idx).GetTrtDims());
3696   // Convert axis from the TensorFlow format to TensorRT format.
3697   int trt_axis{0};
3698   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.NumDims() + 1, node_def.name(),
3699                                  params->use_implicit_batch, &trt_axis));
3700 
3701   // Compute expanded dimensions and then reshape input tensors.
3702   std::vector<int64_t> tensor_dims(dims.begin(), dims.end());
3703   tensor_dims.insert(tensor_dims.begin() + trt_axis, 1);
3704   std::vector<ITensorProxyPtr> expanded_tensors;
3705 
3706   int input_index = 0;
3707   for (const TRT_TensorOrWeights& input : inputs) {
3708     ITensorProxyPtr expanded_tensor = nullptr;
3709     if (input.is_tensor() && !params->use_implicit_batch &&
3710         !HasStaticShape(dims)) {
3711       if (!params->validation_only) {
3712         TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims(
3713             /*input=*/input.tensor(),
3714             /*dims=*/dims.AsTrtDims(),
3715             /*axis=*/trt_axis,
3716             /*params=*/params,
3717             /*output=*/&expanded_tensor,
3718             /*op_instance=*/input_index));
3719       }
3720     } else {
3721       TF_RETURN_IF_ERROR(PrepareTensorForShape(
3722           /*converter=*/params->converter,
3723           /*input=*/input,
3724           /*dims=*/DimsAdapter(tensor_dims),
3725           /*validation_only=*/params->validation_only,
3726           /*tensor=*/&expanded_tensor,
3727           /*node_def=*/node_def,
3728           /*op_instance=*/input_index));
3729     }
3730     if (!params->validation_only) {
3731       expanded_tensors.push_back(expanded_tensor);
3732     }
3733     input_index++;
3734   }
3735   if (params->validation_only) return Status::OK();
3736 
3737   // If there is only one tensor in the input, return the expanded tensor.
3738   if (num_inputs == 1) {
3739     params->outputs->push_back(TRT_TensorOrWeights(expanded_tensors[0]));
3740     return Status::OK();
3741   }
3742 
3743   // Otherwise, concatenate expanded tensors.
3744   std::vector<nvinfer1::ITensor*> trt_expanded_tensors;
3745   for (const auto& t : expanded_tensors) {
3746     trt_expanded_tensors.push_back(t->trt_tensor());
3747   }
3748   nvinfer1::IConcatenationLayer* layer =
3749       params->converter->network()->addConcatenation(
3750           static_cast<nvinfer1::ITensor* const*>(trt_expanded_tensors.data()),
3751           expanded_tensors.size());
3752   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3753   params->converter->SetLayerName(layer, node_def, "concat");
3754   // Note that trt_axis stays the same even after expanding tensors at the axis.
3755   layer->setAxis(trt_axis);
3756   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
3757   return Status::OK();
3758 }
3759 
3760 Status ConvertPad(OpConverterParams* params) {
3761   const auto& inputs = params->inputs;
3762   const auto& node_def = params->node_def;
3763   TF_RETURN_IF_ERROR(
3764       CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
3765   TF_RETURN_IF_ERROR(AllowDataTypes(
3766       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT8}));
3767 
3768   // Implement tensor binaryOp weight [channel wise] for now;
3769   ITensorProxyPtr tensor = inputs.at(0).tensor();
3770   const auto dims = tensor->getDimensions();
3771   // Restore implicit batch dimension
3772   const int nb_dims =
3773       params->use_implicit_batch ? dims.nbDims + 1 : dims.nbDims;
3774 
3775   // TODO(tfeher): Support nb_dims < 4 by inserting extra dimensions to the
3776   // original input.
3777   if (nb_dims < 4) {
3778     return errors::InvalidArgument("Convertpad requires at least 4D input");
3779   }
3780   TRT_ShapedWeights pads = inputs.at(1).weights();
3781 
3782   // TODO(jie): handle data type conversion for TRT?
3783   DataType padding_dtype{DataType::DT_INT32};
3784   TF_RETURN_IF_ERROR(
3785       GetNodeAttr(AttrSlice(node_def), "Tpaddings", &padding_dtype));
3786 
3787   if (pads.Shape().dim(0) != nb_dims || pads.Shape().dim(1) != 2) {
3788     return errors::InvalidArgument("Paddings must be a weight with shape ",
3789                                    "[n, 2], where n is the rank of input ",
3790                                    "tensor");
3791   }
3792 
3793   // Only expect to handle INT32 as attributes for now
3794   if (padding_dtype != DataType::DT_INT32) {
3795     return errors::Unimplemented("Tpaddings supports only DT_INT32");
3796   }
3797   auto pad_data = pads.GetPointer<int>();
3798 
3799   std::vector<int32_t> tf_pad_index;
3800   for (int i = 0; i < nb_dims; i++) {
3801     if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) {
3802       tf_pad_index.push_back(i);
3803     }
3804   }
3805 
3806   // No padding at all, we should exit
3807   if (tf_pad_index.empty()) {
3808     params->outputs->push_back(inputs.at(0));
3809     return Status::OK();
3810   }
3811 
3812   // TRT pad layer can only support padding on up to 2 dimensions (TRT-2579).
3813   // TODO(tfeher): Use multiple TRT pad layers to support padding on more than 2
3814   // dimensions.
3815   if (tf_pad_index.size() > 2) {
3816     return errors::InvalidArgument(
3817         "Padding layer does not support padding on > 2");
3818   }
3819 
3820   // Padding on batch dimension is not supported
3821   if (params->use_implicit_batch && tf_pad_index[0] == 0) {
3822     return errors::InvalidArgument(
3823         "Padding layer does not support padding on batch dimension");
3824   }
3825 
3826   if (params->validation_only) return Status::OK();
3827 
3828   // TRT can only do the padding at the last two dimensions. We transpose the
3829   // input tensor if needed.
3830   bool transposed_pad = false;
3831   std::vector<int> transpose_idx(nb_dims);
3832   std::iota(transpose_idx.begin(), transpose_idx.end(), 0);
3833 
3834   // trt_pad_index denotes the actual idx where the padding is performed by TRT.
3835   std::vector<int> trt_pad_index{nb_dims - 2, nb_dims - 1};
3836 
3837   // How many zeros are padded at the last two dimensions.
3838   nvinfer1::DimsHW pre_padding(0, 0);
3839   nvinfer1::DimsHW post_padding(0, 0);
3840 
3841   // Dimension to set in the pre_padding and post_padding array.
3842   std::vector<int> trt_pre_post_padding_index{0, 1};
3843 
3844   // Two special cases where we can avoid permutations.
3845   if (tf_pad_index.size() == 1 && tf_pad_index[0] == nb_dims - 1) {
3846     // Only one dimension needs to be padded. We store its index at
3847     // trt_pad_index[0]. We ignore trt_pad_index[1].
3848     trt_pad_index[0] = nb_dims - 1;
3849     trt_pre_post_padding_index[0] = 1;
3850   }
3851   if (tf_pad_index.size() == 2 && tf_pad_index[1] == nb_dims - 2) {
3852     // tf_pad_index only has two values that are in ascending order. If
3853     // tf_pad_index[1] is nb_dims-2, then swapping the two values in
3854     // trt_pad_index here makes it possible to only swap one pair of dimensions
3855     // (swap tf_pad_index[0] with nb_dims-1) in the input tensor. Otherwise, we
3856     // would have to swap two pairs of dimensions in the input tensor:
3857     // (tf_pad_index[0] with nb_dims-2) and (tf_pad_index[1], with nb_dims-1).
3858     // Here is an example for a 4D input tensor:
3859     // tf_pad_index = [1, 2]
3860     // trt_pad_index = [3, 2]
3861     // transpose_idx = [0, 3, 2, 1]
3862     std::swap(trt_pad_index[0], trt_pad_index[1]);
3863     std::swap(trt_pre_post_padding_index[0], trt_pre_post_padding_index[1]);
3864   }
3865 
3866   for (int i = 0; i < tf_pad_index.size(); i++) {
3867     const int tf_index = tf_pad_index[i];
3868     const int trt_index = trt_pad_index[i];
3869     const int k = trt_pre_post_padding_index[i];
3870     pre_padding.d[k] = pad_data[tf_index * 2];
3871     post_padding.d[k] = pad_data[tf_index * 2 + 1];
3872     if (tf_index != trt_index) {
3873       transposed_pad = true;
3874       std::swap(transpose_idx[tf_index], transpose_idx[trt_index]);
3875     }
3876   }
3877 
3878   if (transposed_pad) {
3879     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3880         tensor, transpose_idx, &tensor, node_def, "to_pad"));
3881   }
3882 
3883   nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
3884       *tensor->trt_tensor(), pre_padding, post_padding);
3885   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
3886   params->converter->SetLayerName(layer, node_def);
3887   ITensorProxyPtr output_tensor = layer->getOutput(0);
3888 
3889   if (transposed_pad) {
3890     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
3891         output_tensor, transpose_idx, &output_tensor, node_def, "from_pad"));
3892   }
3893 
3894   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
3895   return Status::OK();
3896 }
3897 
3898 Status ConvertSplitHelper(OpConverterParams* params,
3899                           const TRT_TensorOrWeights& input, int tf_axis,
3900                           int num_splits, bool squeeze_after) {
3901   const auto& node_def = params->node_def;
3902   const nvinfer1::Dims dims = input.GetTrtDims();
3903   // Convert axis.
3904   int trt_axis;
3905   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
3906                                  params->use_implicit_batch, &trt_axis));
3907 
3908   if (dims.d[trt_axis] < 0) {
3909     return errors::InvalidArgument("Dimension ", tf_axis,
3910                                    " must have statically defined dimensions");
3911   }
3912 
3913   // Dimension must equal num_splits for Unstack (when squeeze_after is true)
3914   if (squeeze_after && dims.d[trt_axis] != num_splits) {
3915     return errors::InvalidArgument(
3916         "Dimension ", tf_axis, " has size ", dims.d[trt_axis],
3917         " which is not equal to num of ", num_splits);
3918   }
3919   // Dimension must be evenly divisible by num_splits.
3920   if (dims.d[trt_axis] % num_splits != 0) {
3921     return errors::InvalidArgument("Dimension ", tf_axis, " of size ",
3922                                    dims.d[trt_axis],
3923                                    " is not evenly divisible by ", num_splits);
3924   }
3925 
3926   // Create parameters for StridedSliceHelper.
3927   // Slice will begin on zero for all dims, except the one being split which
3928   // will change.
3929   std::vector<int> begin(dims.nbDims, 0);
3930   std::vector<int64> input_dims(dims.d, dims.d + dims.nbDims);
3931 
3932   // Determine size of split. Slice will get the full length of all dims, except
3933   // the one being split. Undefined dims (-1) will translate to a size of -1
3934   // which will tell StridedSlice to take full length of that dim.
3935   std::vector<int> size(dims.d, dims.d + dims.nbDims);
3936   const int split_size_on_axis = dims.d[trt_axis] / num_splits;
3937   size[trt_axis] = split_size_on_axis;
3938   // Stride will always be 1
3939   std::vector<int> stride(dims.nbDims, 1);
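  // Worked example (explicit batch, hypothetical shapes): splitting a [6, 8]
  // input along trt_axis=0 with num_splits=3 gives split_size_on_axis=2, so
  // the loop below produces slices with begin {0,0}, {2,0}, {4,0}, size
  // {2, 8}, and stride {1, 1}.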
3940   // Add dummy batch dimension
3941   if (params->use_implicit_batch) {
3942     begin.insert(begin.begin(), 0);
3943     size.insert(size.begin(), 1);
3944     stride.insert(stride.begin(), 1);
3945     input_dims.insert(input_dims.begin(), std::max(-1, input.batch_size()));
3946   }
3947   PartialTensorShape input_shape(input_dims);
3948 
3949   // Create final shape for Unpack/Unstack, where split axis is squeezed.
3950   std::optional<nvinfer1::Dims> final_shape_for_unpack = std::nullopt;
3951 
3952   // We can't use final_shape_for_unpack when the input dimensions are not
3953   // fully defined.
3954   const bool is_dynamic_shape = !HasStaticShape(dims);
3955   if (squeeze_after && !is_dynamic_shape) {
3956     std::vector<int> size_after_squeeze(size);
3957     const int tf_axis = trt_axis + (params->use_implicit_batch ? 1 : 0);
3958     size_after_squeeze.erase(size_after_squeeze.begin() + tf_axis);
3959     DimsAdapter adap(size_after_squeeze);
3960     if (params->use_implicit_batch)
3961       TF_RETURN_IF_ERROR(adap.RemoveBatchDimension());
3962     final_shape_for_unpack = adap.AsTrtDims();
3963   }
3964 
3965   // Slice the input. ConvertStridedSliceHelper will push the outputs onto
3966   // params->outputs.
3967   for (int i = 0; i < num_splits; ++i) {
3968     const int tf_axis = trt_axis + (params->use_implicit_batch ? 1 : 0);
3969     begin[tf_axis] = i * split_size_on_axis;
3970 
3971     // Stride is 1 for all dims.
3972     absl::InlinedVector<int64, 4> stride_v(begin.size(), 1);
3973     absl::InlinedVector<int64, 4> begin_v;
3974     absl::InlinedVector<int64, 4> end_v;
3975     for (int i = 0; i < begin.size(); i++) {
3976       end_v.push_back(begin[i] + size[i]);
3977       begin_v.push_back(begin[i]);
3978     }
3979 
3980     TF_RETURN_IF_ERROR(ConvertStridedSliceHelper(
3981         params, input, input_shape, begin_v, stride_v, end_v,
3982         final_shape_for_unpack,
3983         /*op_instance=*/i, /*strided_slice_spec=*/std::nullopt));
3984   }
3985   if (params->validation_only) return Status::OK();
3986 
3987   // Squeeze for dynamic shapes
3988   if (squeeze_after && is_dynamic_shape) {
3989     for (int i = 0; i < params->outputs->size(); i++) {
3990       ITensorProxyPtr output_tensor = nullptr;
3991       std::vector<int> in_dims(dims.d, dims.d + dims.nbDims);
3992       in_dims[trt_axis] = 0;
3993       TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
3994           /*input=*/params->outputs->at(i).tensor(),
3995           /*input_dims=*/&in_dims,
3996           /*params=*/params,
3997           /*output=*/&output_tensor,
3998           /*op_instance=*/i));
3999       (*params->outputs)[i] = TRT_TensorOrWeights(output_tensor);
4000     }
4001   }
4002   return Status::OK();
4003 }
4004 
4005 Status ConvertSplit(OpConverterParams* params) {
4006   const auto& inputs = params->inputs;
4007   const auto& node_def = params->node_def;
4008   TF_RETURN_IF_ERROR(
4009       CheckInputsWeights(*params, {{"axis", true}, {"value", false}}));
4010   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {
4011                                                  DataType::DT_FLOAT,
4012                                                  DataType::DT_HALF,
4013                                                  DataType::DT_INT32,
4014                                              }));
4015   int tf_axis = inputs.at(0).weights().GetSpan<int>()[0];
4016 
4017   int num_split;
4018   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "num_split", &num_split));
4019 
4020   return ConvertSplitHelper(params, inputs.at(1), tf_axis, num_split, false);
4021 }
4022 
4023 Status ConvertUnpack(OpConverterParams* params) {
4024   const auto& inputs = params->inputs;
4025   const auto& node_def = params->node_def;
4026   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"value", false}}));
4027   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {
4028                                                  DataType::DT_FLOAT,
4029                                                  DataType::DT_HALF,
4030                                                  DataType::DT_INT32,
4031                                              }));
4032   // Input must be rank 1 or higher, since we can't unpack on axis 0.
4033   if (inputs.at(0).GetTrtDims().nbDims == 0) {
4034     return errors::Unimplemented(
4035         "Input \"value\" for Unpack must be rank 2 or greater");
4036   }
4037 
4038   int tf_axis = 0, num = 0;
4039   AttrSlice attrs(node_def);
4040   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &tf_axis));
4041   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "num", &num));
4042 
4043   return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
4044 }
4045 
4046 Status ConvertCast(OpConverterParams* params) {
4047   auto unsupport_cast_error = [&](string msg) {
4048     return errors::Unimplemented("Cast op is not supported - ", msg);
4049   };
4050 
4051   if (isExperimentalFeatureActivated("reject_all_fp_cast_ops")) {
4052     LOG(WARNING) << "`TF_TRT_EXPERIMENTAL_FEATURES=reject_all_fp_cast_ops` is "
4053                  << "meant as a workaround. If the Cast converter leads to any "
4054                  << "performance or accuracy regression, please open an issue "
4055                  << "on GitHub.";
4056     return unsupport_cast_error(
4057         "TF_TRT_EXPERIMENTAL_FEATURES=reject_all_fp_cast_ops has been defined");
4058   }
4059 
4060   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
4061 
4062   DataType input_type;
4063   TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
4064 
4065   if (allowed_types.find(input_type) == allowed_types.end()) {
4066     return unsupport_cast_error(
4067         StrCat("Allowed input dtypes: [", DataTypeString(DataType::DT_FLOAT),
4068                ", ", DataTypeString(DataType::DT_HALF),
4069                "]. Received: ", DataTypeString(input_type)));
4070   }
4071 
4072   DataType output_type;
4073   TF_RETURN_IF_ERROR(GetNodeDefTfType(params->node_def, &output_type,
4074                                       kCastOutputTypeAttrName));
4075 
4076   if (allowed_types.find(output_type) == allowed_types.end()) {
4077     return unsupport_cast_error(
4078         StrCat("Allowed output dtypes: [", DataTypeString(DataType::DT_FLOAT),
4079                ", ", DataTypeString(DataType::DT_HALF),
4080                "]. Received: ", DataTypeString(output_type)));
4081   }
4082 
4083   return ConvertIdentity(params);
4084 }
4085 
4086 Status ConvertConcat(OpConverterParams* params) {
4087   const auto& inputs = params->inputs;
4088   const auto& node_def = params->node_def;
4089 
4090   int num_inputs{0};
4091   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "N", &num_inputs));
4092 
4093   if (num_inputs != static_cast<int>(inputs.size()) - 1) {
4094     return errors::InvalidArgument(
4095         "Number of inputs for ConcatV2 is inconsistent with N attributes.");
4096   }
4097   // Validate inputs.
4098   std::vector<std::pair<string, TrtInputArg>> inputs_kinds;
4099   TrtInputArg expected_input =
4100       params->use_implicit_batch ? TrtInputArg::kTensor : TrtInputArg::kBoth;
4101 
4102   inputs_kinds.reserve(num_inputs);
4103   for (int i = 0; i < num_inputs; ++i) {
4104     inputs_kinds.push_back({StrCat("values_", i), expected_input});
4105   }
4106   inputs_kinds.push_back({"axis", TrtInputArg::kWeight});
4107   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_kinds));
4108 
4109   std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
4110                                    DataType::DT_INT32};
4111 
4112   TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
4113   const auto axis = inputs.at(num_inputs).weights().GetSpan<int>();
4114   if (axis.size() != 1) {
4115     return errors::InvalidArgument("Axis for ConcatV2 must be a scalar");
4116   }
4117   int trt_axis = 0;
4118   const auto dim = inputs.at(0).GetTrtDims();
4119   TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(),
4120                                  params->use_implicit_batch, &trt_axis));
4121   // Check that dimensions match on non-concatenate axis.
4122   TF_RETURN_IF_ERROR(VerifyShapesMatch(
4123       absl::Span<const TRT_TensorOrWeights>(inputs).first(num_inputs), trt_axis,
4124       node_def.name()));
4125   if (params->validation_only) return Status::OK();
4126 
4127   // Gather inputs as tensors
4128   std::vector<ITensorProxyPtr> input_tensors;
4129   input_tensors.reserve(num_inputs);
4130 
4131   for (int i = 0; i < num_inputs; i++) {
4132     if (inputs.at(i).is_tensor()) {
4133       input_tensors.push_back(inputs.at(i).tensor());
4134     } else {
4135       input_tensors.push_back(params->converter->CreateConstantLayer(
4136           inputs.at(i).weights(), inputs.at(i).GetTrtDims()));
4137     }
4138   }
4139   std::vector<nvinfer1::ITensor*> trt_input_tensors;
4140   for (const auto& t : input_tensors) {
4141     trt_input_tensors.push_back(t->trt_tensor());
4142   }
4143   nvinfer1::IConcatenationLayer* layer =
4144       params->converter->network()->addConcatenation(
4145           static_cast<nvinfer1::ITensor* const*>(trt_input_tensors.data()),
4146           input_tensors.size());
4147   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4148   params->converter->SetLayerName(layer, node_def);
4149   layer->setAxis(trt_axis);
4150   params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
4151   return Status::OK();
4152 }
4153 
4154 Status ConvertFusedBatchNorm(OpConverterParams* params) {
4155   const auto& inputs = params->inputs;
4156   const auto& node_def = params->node_def;
4157   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false},
4158                                                   {"scale", true},
4159                                                   {"offset", true},
4160                                                   {"mean", true},
4161                                                   {"variance", true}}));
4162   TF_RETURN_IF_ERROR(
4163       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4164 
4165   float epsilon{0.1f};
4166   string data_format;
4167   bool is_training{false};
4168   AttrSlice attrs(node_def);
4169   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "epsilon", &epsilon));
4170   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
4171   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "is_training", &is_training));
4172 
4173   if (is_training) {
4174     // Trying to use batchnorm in training mode is a very common problem.
4175     // Because the error message will only be printed in VLOG(1) by the
4176     // segmenter, we issue a special warning so that users will actually see it.
4177     LOG_WARNING_WITH_PREFIX
4178         << node_def.op() << " only supports is_training=false. If you "
4179         << "are using Keras, please call "
4180         << "keras.backend.set_learning_phase(0) before constructing "
4181         << "your model. At " << node_def.name();
4182     return errors::Unimplemented(node_def.op(),
4183                                  " only supports is_training=false");
4184   }
4185   ITensorProxyPtr tensor = inputs.at(0).tensor();
4186   if (!params->use_implicit_batch) {
4187     // This check is to make sure that channel dimension is known during
4188     // conversion.
4189     //
4190     // We check this only in explicit batch mode and reject an op with unknown
4191     // channel dimension during segmentation. In implicit batch mode we have
4192     // known shapes during conversion even though the shapes may not be known
4193     // during segmentation (see the actual argument for input_shapes when
4194     // ConvertGraphDefToEngine is called from TRTEngineOp::BuildEngine).
4195     int channel_dim = (data_format == "NCHW" ? 1 : 3);
4196     if (tensor->getDimensions().d[channel_dim] == -1) {
4197       return errors::InvalidArgument("Channel dimension must be static");
4198     }
4199   }
4200   //  Check parameter types
4201   auto parameter_type = inputs.at(1).weights().TrtDType();
4202   if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
4203       (parameter_type != nvinfer1::DataType::kHALF)) {
4204     return errors::Unimplemented(
4205         "Only float32 or float16 weight data type is supported,", " got ",
4206         DebugString(parameter_type));
4207   }
4208   for (int i = 1; i < 5; i++) {
4209     if (inputs.at(i).weights().TrtDType() != parameter_type) {
4210       return errors::Unimplemented(
4211           "Inconsistent parameter type for batchnorm is not supported");
4212     }
4213   }
4214 
4215   TRT_ShapedWeights dummy_power_weights(parameter_type);
4216   size_t nweight = 0;
4217   for (int i = 1; i < 5; i++) {
4218     nweight = std::max<size_t>(nweight, inputs.at(i).weights().count());
4219   }
4220   const TRT_ShapedWeights* ptr_shape_weights = nullptr;
4221   for (int i = 1; i < 5; i++) {
4222     if (inputs.at(i).weights().count() == nweight) {
4223       ptr_shape_weights = &(inputs.at(i).weights());
4224     } else if (inputs.at(i).weights().count() != 1) {
4225       return errors::InvalidArgument("Inconsistent batchnorm parameter count");
4226     }
4227   }
4228   if (params->validation_only) return Status::OK();
4229 
4230   //  We could technically have two weights with different shapes; that would
4231   //  require two addScale ops, which is arguably less performant.
4232   StatusOr<TRT_ShapedWeights> combined_scale_weights =
4233       params->weight_store->GetTempWeights(*ptr_shape_weights);
4234   TRT_ENSURE_OK(combined_scale_weights);
4235   StatusOr<TRT_ShapedWeights> combined_offset_weights =
4236       params->weight_store->GetTempWeights(*ptr_shape_weights);
4237   TRT_ENSURE_OK(combined_offset_weights);
4238 
4239   const Eigen::half* cast_vals_array[4];
4240   const float* vals_array[4];
4241   for (int j = 0; j < 4; j++) {
4242     cast_vals_array[j] = inputs.at(j + 1).weights().GetPointer<Eigen::half>();
4243     vals_array[j] = inputs.at(j + 1).weights().GetPointer<float>();
4244   }
4245   Eigen::half* cast_combined_scale_vals =
4246       combined_scale_weights->GetPointer<Eigen::half>();
4247   Eigen::half* cast_combined_offset_vals =
4248       combined_offset_weights->GetPointer<Eigen::half>();
4249   float* combined_scale_vals = combined_scale_weights->GetPointer<float>();
4250   float* combined_offset_vals = combined_offset_weights->GetPointer<float>();
4251 
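  // The loop below folds the four batchnorm parameters into one affine
  // transform:
  //   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
  //     = combined_scale * x + combined_offset,
  // where combined_scale = scale / sqrt(variance + epsilon) and
  // combined_offset = offset - mean * combined_scale.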
4252   for (size_t i = 0; i < nweight; ++i) {
4253     float batchnorm_data[4];
4254     for (int j = 0; j < 4; j++) {
4255       if (inputs.at(j + 1).weights().count() != 1) {
4256         if (parameter_type == nvinfer1::DataType::kFLOAT) {
4257           batchnorm_data[j] = vals_array[j][i];
4258         } else if (parameter_type == nvinfer1::DataType::kHALF) {
4259           batchnorm_data[j] = static_cast<float>(cast_vals_array[j][i]);
4260         }
4261       } else {
4262         if (parameter_type == nvinfer1::DataType::kFLOAT) {
4263           batchnorm_data[j] = vals_array[j][0];
4264         } else if (parameter_type == nvinfer1::DataType::kHALF) {
4265           batchnorm_data[j] = static_cast<float>(cast_vals_array[j][0]);
4266         }
4267       }
4268     }
4269     float scale = batchnorm_data[0];
4270     float offset = batchnorm_data[1];
4271     float mean = batchnorm_data[2];
4272     float variance = batchnorm_data[3];
4273     float combined_scale_val = scale / sqrtf(variance + epsilon);
4274     float combined_offset_val = offset - mean * combined_scale_val;
4275     if (parameter_type == nvinfer1::DataType::kFLOAT) {
4276       combined_scale_vals[i] = combined_scale_val;
4277       combined_offset_vals[i] = combined_offset_val;
4278     } else if (parameter_type == nvinfer1::DataType::kHALF) {
4279       cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
4280       cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
4281     }
4282   }
4283 
4284   ITensorProxyPtr output_tensor;
4285 
4286   if (data_format == "NCHW") {
4287     // IScaleLayer CHANNEL mode requires NCHW format.
4288     nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
4289     nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
4290         *tensor->trt_tensor(), mode, combined_offset_weights->GetTrtWeights(),
4291         combined_scale_weights->GetTrtWeights(),
4292         nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0});
4293     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4294     params->converter->SetLayerName(layer, node_def);
4295     output_tensor = layer->getOutput(0);
4296   }
4297   if (data_format == "NHWC") {
4298     // nweight is the number of channels. TensorRT IElementWiseLayer supports
4299     // implicit broadcasting for dimensions of size 1.
4300     nvinfer1::Dims dims = tensor->getDimensions();
4301     for (int i = 0; i < dims.nbDims - 1; i++) {
4302       dims.d[i] = 1;
4303     }
4304     dims.d[dims.nbDims - 1] = nweight;
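    // Worked example (implicit batch, hypothetical shapes): for an NHWC input
    // with TRT dims [H, W, C] the scale and offset constants get dims
    // [1, 1, C], so the Mul/Add element-wise layers below broadcast them over
    // H and W.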
4305     StatusOr<TRTNetworkBuilder> builder = TRTNetworkBuilder::Create(
4306         params->converter->network(), params->weight_store);
4307     TRT_ENSURE_OK(builder);
4308     auto scale_constant_layer = builder->WeightsToConstant(
4309         combined_scale_weights->GetTrtWeights(), dims);
4310     ITensorProxyPtr scale_constant = (*scale_constant_layer)->getOutput(0);
4311     auto scale_layer =
4312         builder->Mul(tensor->trt_tensor(), scale_constant->trt_tensor());
4313     auto offset_constant_layer = builder->WeightsToConstant(
4314         combined_offset_weights->GetTrtWeights(), dims);
4315     ITensorProxyPtr offset_constant = (*offset_constant_layer)->getOutput(0);
4316     auto offset_layer = builder->Add((*scale_layer)->getOutput(0),
4317                                      offset_constant->trt_tensor());
4318     output_tensor = (*offset_layer)->getOutput(0);
4319   }
4320 
4321   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4322   return Status::OK();
4323 }
4324 
4325 Status ConvertGather(OpConverterParams* params) {
4326   const auto& inputs = params->inputs;
4327   const auto& node_def = params->node_def;
4328   // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
4329   // option for an input to be either tensor or weight.
4330   TF_RETURN_IF_ERROR(
4331       CheckInputsWeights(*params, {{"params", TrtInputArg::kBoth},
4332                                    {"indices", TrtInputArg::kBoth},
4333                                    {"axis", TrtInputArg::kWeight}}));
4334 
4335   const auto& params_input = inputs.at(0);
4336   const auto& indices_input = inputs.at(1);
4337   const auto& axis_input = inputs.at(2);
4338 
4339   TF_RETURN_IF_ERROR(AllowDataTypes(
4340       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
4341       /*dtype_attr_name=*/"Tparams"));
4342   TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32},
4343                                     /*dtype_attr_name=*/"Tindices"));
4344 
4345   absl::Span<const int> axis = axis_input.weights().GetSpan<int>();
4346   if (axis.size() != 1) {
4347     return errors::InvalidArgument("Axis for GatherV2 must be a scalar");
4348   }
4349   int trt_axis = 0;
4350   TF_RETURN_IF_ERROR(ConvertAxis(
4351       axis[0], params_input.GetTrtDims().nbDims, node_def.name(),
4352       params->use_implicit_batch && params_input.is_tensor(), &trt_axis));
4353   if (params->use_implicit_batch && params_input.is_weights() &&
4354       trt_axis != 0) {
4355     return errors::Unimplemented(
4356         "The input axis must be zero when params is a weight.");
4357   }
4358   if (params->use_implicit_batch &&
4359       (params_input.is_tensor() == indices_input.is_tensor()) &&
4360       (indices_input.batch_size() != 1 || params_input.batch_size() != 1)) {
4361     return errors::Unimplemented(
4362         "Params and indices must have a batch size of 1 when params and indices"
4363         " are both tensors or both constants.");
4364   }
4365 
4366   auto get_rank = [params](const auto& input) {
4367     return input.GetTrtDims().nbDims +
4368            (params->use_implicit_batch && input.is_tensor() ? 1 : 0);
4369   };
4370   // If both inputs are tensors, the TF gather result will have rank:
4371   // (params.nbDims + 1) + (indices.nbDims + 1) - 1,
4372   // where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
4373   // the TF rank, so we don't have to add the + 1.
4374   const int params_tf_rank = get_rank(params_input);
4375   const int indices_tf_rank = get_rank(indices_input);
4376   const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
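  // Worked example (implicit batch, hypothetical shapes): a params tensor of
  // TF shape [N, 5, 7] (TRT rank 2, TF rank 3) gathered with an indices tensor
  // of TF shape [N, 3] (TF rank 2) gives tf_gather_output_rank = 3 + 2 - 1 = 4.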
4377   if (tf_gather_output_rank >
4378       nvinfer1::Dims::MAX_DIMS + (params->use_implicit_batch ? 1 : 0)) {
4379     return errors::InvalidArgument(
4380         "Result of gather has dimension greater than ",
4381         nvinfer1::Dims::MAX_DIMS + 1);
4382   }
4383   if (params->validation_only) return Status::OK();
4384 
4385   // Convert input or indices to tensor if it is a constant.
4386   auto populate_tensor = [params](const auto& input) -> ITensorProxyPtr {
4387     ITensorProxyPtr result_tensor = nullptr;
4388 
4389     if (input.is_weights()) {
4390       result_tensor = params->converter->CreateConstantLayer(
4391           input.weights(), input.GetTrtDims());
4392     } else {
4393       result_tensor = input.tensor();
4394     }
4395 
4396     return result_tensor;
4397   };
4398 
4399   ITensorProxyPtr params_tensor = populate_tensor(params_input);
4400   ITensorProxyPtr indices_tensor = populate_tensor(indices_input);
4401 
4402   // Note on how IGatherLayer works: if both the data and indices tensors have
4403   // a batch size dimension of size N, it performs:
4404   // for batchid in xrange(N):
4405   //   output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
4406   //       data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
4407   nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
4408       *params_tensor->trt_tensor(), *indices_tensor->trt_tensor(), trt_axis);
4409   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4410   params->converter->SetLayerName(layer, node_def);
4411 
4412   ITensorProxyPtr output_tensor = layer->getOutput(0);
4413   nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
4414 
4415   if (params->use_implicit_batch) {
4416     // Note for the "- 2": one is for the output batch dim encapsulated by
4417     // TF-TRT, and the other is for the output dimension that is squeezed by
4418     // IGatherLayer because of the implicit batch dim in the indices (see the
4419     // above note).
4420     const int expected_trt_output_rank = tf_gather_output_rank -
4421                                          (params_input.is_tensor() ? 1 : 0) -
4422                                          (indices_input.is_tensor() ? 1 : 0);
4423 
4424     if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
4425       return errors::Internal(
4426           "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
4427           expected_trt_output_rank,
4428           ", actual nbDims: ", trt_gather_output_dims.nbDims);
4429     }
4430   }
4431   // Reshape the output so after adding the implicit batch dim it'll match the
4432   // output shape of TF GatherV2.
4433   if (params->use_implicit_batch && params_input.is_tensor() &&
4434       indices_input.is_tensor()) {
4435     for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
4436       trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
4437     }
4438     trt_gather_output_dims.d[trt_axis] = 1;
4439     ++trt_gather_output_dims.nbDims;
4440 
4441     TF_RETURN_IF_ERROR(PrepareTensorForShape(
4442         params->converter, TRT_TensorOrWeights(output_tensor),
4443         trt_gather_output_dims,
4444         /*validation_only=*/false, &output_tensor, node_def));
4445   }
4446 
4447   // When input and indices are both constants, for the supported cases, reshape
4448   // output so that after removing the implicit batch dim it will match the
4449   // output shape of TF GatherV2 op.
4450   if (params->use_implicit_batch && params_input.is_weights() &&
4451       indices_input.is_weights()) {
4452     for (int i = trt_axis; i < trt_gather_output_dims.nbDims - 1; ++i) {
4453       trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i + 1];
4454     }
4455 
4456     // Squeeze the implicit batch dimension out. Note: this works only
4457     // when batch size for both inputs and indices are 1.
4458     --trt_gather_output_dims.nbDims;
4459 
4460     TF_RETURN_IF_ERROR(PrepareTensorForShape(
4461         params->converter, TRT_TensorOrWeights(output_tensor),
4462         trt_gather_output_dims,
4463         /*validation_only=*/false, &output_tensor, node_def));
4464   }
4465 
4466   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4467   return Status::OK();
4468 }
4469 
4470 // Converts the input matrix multiplication node to a fully connected (FC) layer
4471 // if possible, as the FC layer has more tactics and INT implementations.
4472 // Returns the output ITensor* if the node is converted or nullptr if conversion
4473 // is not possible. An error status indicates internal problems during
4474 // conversion.
4475 StatusOr<ITensorProxyPtr> ConvertFullyConnectedImpl(OpConverterParams* params,
4476                                                     TRT_TensorOrWeights input_a,
4477                                                     TRT_TensorOrWeights input_b,
4478                                                     bool transpose_a,
4479                                                     bool transpose_b) {
4480   if (!(!transpose_a && input_a.is_tensor() && input_b.is_weights())) {
4481     VLOG(2) << "Not FC compatible, A must be non transposed tensor, and B "
4482                "must be constant.";
4483     return ITensorProxyPtr(nullptr);
4484   }
4485 
4486   if (!params->use_implicit_batch && input_b.GetTrtDims().nbDims > 2 &&
4487       input_b.GetTrtDims().d[0] != 1) {
4488     // Implicit broadcasting, if needed, has already been considered to
4489     // transform the inputs and ensure the two operands have the same rank here.
4490     // If the inputs have rank >= 3, then d[0] is the explicit batch dimension.
4491     // The weight (input_b) must have batch size 1 in implicit batch mode.
4492     VLOG(2) << "Not FC compatible, if B has an explicit batch dimension, then "
4493                "it must be 1.";
4494     return ITensorProxyPtr(nullptr);
4495   }
4496 
4497   nvinfer1::Dims input_dim = input_a.GetTrtDims();
4498   if (input_dim.d[input_dim.nbDims - 1] == -1) {
4499     VLOG(2) << "Not FC compatible, last dim of A must be static.";
4500     return ITensorProxyPtr(nullptr);
4501   }
4502 
4503   if (input_dim.nbDims + 2 > nvinfer1::Dims::MAX_DIMS) {
4504     VLOG(2) << "Not FC compatible, cannot expand A's shape.";
4505     return ITensorProxyPtr(nullptr);
4506   }
4507 
4508   // Add two trailing 1's because FC layer combines the last three dims.
4509   ITensorProxyPtr tensor_a = nullptr;
4510 
4511   // Initialize the elements of reshape_dim to 0. A value 0 in
4512   // reshape_dim(i) will preserve the i-th dimension value from the shape of
4513   // input_a. Add two trailing dimensions of size 1.
4514   auto reshape_dim = DimsAdapter(input_dim.nbDims,
4515                                  DimsAdapter::StorageType(input_dim.nbDims, 0))
4516                          .Append(1)
4517                          .Append(1);
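  // Worked example (hypothetical shapes): for input_a with dims [8, 256] this
  // builds reshape_dim {0, 0, 1, 1}, so tensor_a below gets shape
  // [8, 256, 1, 1]; the FC layer then collapses the last three dims when
  // producing its output.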
4518 
4519   const NodeDef& node_def = params->node_def;
4520   TF_RETURN_IF_ERROR(PrepareTensorForShape(
4521       params->converter, input_a, reshape_dim,
4522       /*validation_only=*/false, &tensor_a, node_def, /*op_instance=*/0,
4523       /*origin_node_name=*/"FULLY_CONNECTED"));
4524 
4525   VLOG(2) << "New shape of A " << DebugString(tensor_a->getDimensions());
4526 
4527   TRT_ShapedWeights weights_b = input_b.weights();
4528   TRT_ShapedWeights weights_2D(weights_b);
4529   if (weights_b.Shape().NumDims() > 2) {
4530     // Combine first nbDims-1 dims into a single dim, e.g. for a 4D tensor we
4531     // transform [N, H, W, C] -> [N*H*W, C]. This is only valid if all batch
4532     // dimensions are 1.
4533     if (std::any_of(weights_b.Shape().begin(),
4534                     weights_b.Shape().begin() + weights_b.Shape().NumDims() - 2,
4535                     [](int d) { return d != 1; })) {
4536       VLOG(2) << "Not FC compatible, B has a batch dim larger than 1";
4537       return ITensorProxyPtr(nullptr);
4538     }
4539     int k = weights_b.Shape().dim(weights_b.Shape().NumDims() - 1);
4540     nvinfer1::Dims dims{2, {static_cast<int>(weights_b.count() / k), k}};
4541     TF_RETURN_IF_ERROR(weights_2D.SetShape(dims));
4542   }
4543 
4544   // FC layer will transpose weights, so we need to pre-transpose.
4545   TRT_ShapedWeights weights(weights_2D.TrtDType());
4546   if (!transpose_b) {
4547     auto tmp = params->weight_store->GetTempWeights(weights_2D);
4548     TRT_ENSURE_OK(tmp);
4549     weights = std::move(tmp).value();
4550     ReorderCKtoKC(weights_2D, &weights);
4551   } else {
4552     weights = weights_2D;
4553   }
4554   TRT_ShapedWeights biases(weights.TrtDType());
4555   int k = weights.Shape().dim(weights.Shape().NumDims() - 1);
4556   const int noutput = weights.count() / k;
4557   VLOG(2) << "Using fully connected layer with k=" << k
4558           << ", n_output=" << noutput
4559           << " weights shape: " << weights.Shape().DebugString()
4560           << " to convert " << node_def.op();
4561   nvinfer1::IFullyConnectedLayer* layer =
4562       params->converter->network()->addFullyConnected(
4563           *tensor_a->trt_tensor(), noutput, weights.GetTrtWeights(),
4564           biases.GetTrtWeights());
4565 
4566   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4567   params->converter->SetLayerName(layer, node_def);
4568   ITensorProxyPtr output_tensor = layer->getOutput(0);
4569 
4570   // A fully connected layer produces output with two trailing singleton
4571   // dimensions. We remove these.
4572   auto output_dim = output_tensor->getDimensions();
4573   output_dim.nbDims -= 2;
4574   // A zero in output_dim indicates copying the corresponding input dimension
4575   // value during reshape.
4576   std::fill(output_dim.d, output_dim.d + output_dim.nbDims, 0);
4577   TF_RETURN_IF_ERROR(PrepareTensorForShape(
4578       params->converter, TRT_TensorOrWeights(output_tensor), output_dim,
4579       /*validation_only=*/false, &output_tensor, node_def,
4580       /*op_instance=*/1, /*origin_node_name=*/"FULLY_CONNECTED"));
4581   return output_tensor;
4582 }
4583 
4584 StatusOr<ITensorProxyPtr> ConvertMatMulImpl(OpConverterParams* params,
4585                                             TRT_TensorOrWeights input_a,
4586                                             TRT_TensorOrWeights input_b,
4587                                             bool transpose_a,
4588                                             bool transpose_b) {
4589   if (params->use_implicit_batch) {
4590     // In implicit batch mode we are very limited in when we can multiply 2D
4591     // matrices. If input_A is a 2D tensor, then nbDims==1 (implicit batch dim
4592     // not counted). If A is not transposed and B is weight, then we can convert
4593     // this treating A as a batch of vectors. This is the only possibility
4594     // to implement MatMul with 2D input in implicit batch mode.
4595     if ((input_a.GetTrtDims().nbDims < 2 &&
4596          (transpose_a || !input_b.is_weights())) ||
4597         (input_b.GetTrtDims().nbDims < 2)) {
4598       return errors::InvalidArgument(
4599           "MatMul with 2D tensors requires explicit batch mode, or that tensor"
4600           " A is not transposed and B is a constant tensor.");
4601     }
4602   }
4603 
4604   if (params->validation_only) return ITensorProxyPtr(nullptr);
4605 
4606   StatusOr<ITensorProxyPtr> result = ConvertFullyConnectedImpl(
4607       params, input_a, input_b, transpose_a, transpose_b);
4608   TF_RETURN_IF_ERROR(result.status());
4609   ITensorProxyPtr output = result.ValueOrDie();
4610   if (*output) {
4611     // FC conversion was successful, we can return.
4612     return output;
4613   }
4614   const auto convert_to_itensor =
4615       [&params](TRT_TensorOrWeights operand) -> ITensorProxyPtr {
4616     if (operand.is_tensor()) {
4617       return operand.tensor();
4618     } else {
4619       return params->converter->CreateConstantLayer(operand.weights(),
4620                                                     operand.GetTrtDims());
4621     }
4622   };
4623 
4624   ITensorProxyPtr tensor_a = convert_to_itensor(input_a);
4625   ITensorProxyPtr tensor_b = convert_to_itensor(input_b);
4626 
4627   const auto get_matrix_op = [](ITensorProxyPtr in,
4628                                 bool transpose) -> nvinfer1::MatrixOperation {
4629     return (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
4630                        : nvinfer1::MatrixOperation::kNONE;
4631   };
4632   nvinfer1::MatrixOperation op_a, op_b;
4633   // Note: In implicit batch mode kTRANSPOSE and kNONE are only valid if the
4634   // matrix has at least 2 non-batch dimensions. In implicit batch mode, if A
4635   // has 1 dim (excluding the batch dim), then we can only use kVECTOR, which
4636   // will treat matrix A as a batch of vectors.
4637   op_a = (tensor_a->getDimensions().nbDims < 2)
4638              ? nvinfer1::MatrixOperation::kVECTOR
4639              : get_matrix_op(tensor_a, transpose_a);
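  // Worked example (implicit batch, hypothetical shapes): an input A with TF
  // shape [N, 3] has TRT dims [3] (nbDims == 1), so op_a is set to kVECTOR and
  // A is treated as a batch of N vectors (allowed above only when A is not
  // transposed and B is a constant).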
4640   // In implicit batch mode, if B has only 1 dim (excluding the batch dim) then we
4641   // already reject the case and don't convert. One could consider using the
4642   // kVECTOR flag to express C = MatMul(A, B.T) if A is weight, but the result
4643   // will not have the correct shape: in TRT's implicit batch implementation,
4644   // the result is a batch of vectors D_ji = A_ik * B_jk, where j is the batch
4645   // dimension. In contrast, the TF MatMul op produces C = D.T, and we cannot
4646   // transpose over the batch dimension (implicit batch mode).
4647   op_b = get_matrix_op(tensor_b, transpose_b);
4648 
4649   nvinfer1::IMatrixMultiplyLayer* layer =
4650       params->converter->network()->addMatrixMultiply(
4651           *tensor_a->trt_tensor(), op_a, *tensor_b->trt_tensor(), op_b);
4652 
4653   const auto& node_def = params->node_def;
4654   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4655   params->converter->SetLayerName(layer, node_def);
4656   return ITensorProxyPtr(layer->getOutput(0));
4657 }
4658 
4659 Status ConvertMatMulHelper(OpConverterParams* params,
4660                            TRT_TensorOrWeights input_a,
4661                            TRT_TensorOrWeights input_b, bool transpose_a,
4662                            bool transpose_b) {
4663   StatusOr<ITensorProxyPtr> result =
4664       ConvertMatMulImpl(params, input_a, input_b, transpose_a, transpose_b);
4665   TF_RETURN_IF_ERROR(result.status());
4666   if (!params->validation_only) {
4667     params->outputs->push_back(TRT_TensorOrWeights(result.ValueOrDie()));
4668   }
4669   return Status::OK();
4670 }
4671 
4672 // inputs are both two dimensional (ops::MatMul)
4673 Status ConvertMatMul(OpConverterParams* params) {
4674   const auto& inputs = params->inputs;
4675   const auto& node_def = params->node_def;
4676   TFTRT_CHECK_INPUT_SIZE(inputs.size(), 2, node_def);
4677 
4678   TF_RETURN_IF_ERROR(
4679       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4680 
4681   bool transpose_a = false, transpose_b = false;
4682   AttrSlice attrs(node_def);
4683   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_a", &transpose_a));
4684   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "transpose_b", &transpose_b));
4685 
4686   return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1), transpose_a,
4687                              transpose_b);
4688 }
4689 
4690 Status ConvertBatchMatMul(OpConverterParams* params) {
4691   const auto& inputs = params->inputs;
4692   const auto& node_def = params->node_def;
4693   TFTRT_CHECK_INPUT_SIZE(inputs.size(), 2, node_def);
4694 
4695   TF_RETURN_IF_ERROR(CheckInputsWeights(
4696       *params, {{"x", TrtInputArg::kBoth}, {"y", TrtInputArg::kBoth}}));
4697   // TODO(tfeher): Consider adding INT8 type because FC layer can support it.
4698   TF_RETURN_IF_ERROR(
4699       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4700   if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
4701     // TODO(lsugy): don't assume that if all inputs are weights, grappler
4702     // should fold them, because variables are weights.
4703     return errors::InvalidArgument(
4704         "All inputs are weights, but Grappler is expected to fold them.");
4705   }
4706 
4707   bool transpose_a = false, transpose_b = false;
4708   AttrSlice attrs(node_def);
4709   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "adj_x", &transpose_a));
4710   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "adj_y", &transpose_b));
4711 
4712   // In case input_l is a weight, check whether input_l has a batch dim
4713   // compatible with implicit batch mode.
4714   const auto check_weight_is_not_batched =
4715       [](const TRT_TensorOrWeights& input_l,
4716          const TRT_TensorOrWeights& input_r) {
4717         // There is no way to batch constants in TRT using implicit batch mode.
4718         // Example:
4719         // Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
4720         // Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
4721         // It is not possible to treat the weight input as a batched [3, 6]
4722         // tensor. Batched weight tensors must have batch dim = 1 (after the
4723         // broadcast).
4724         if (input_l.is_weights() &&
4725             input_l.GetTrtDims().nbDims > input_r.GetTrtDims().nbDims &&
4726             input_l.GetTrtDims().d[0] != 1) {
4727           return errors::Unimplemented(
4728               "TensorRT does not support batched constants in implicit batch "
4729               "mode.");
4730         }
4731         return Status::OK();
4732       };
4733   if (params->use_implicit_batch) {
4734     TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
4735     TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
4736   }
4737 
4738   // Broadcast inputs. We don't check feasibility since the dimensions in a
4739   // MatMul don't need to match. For example, consider a valid set of inputs
4740   // which would produce an output of shape [N, T, K]:
4741   // input 0: [N, T, C]
4742   // input 1: [1, C, K]
4743   // Since C != K and T != C, the feasibility check would fail.
4744   auto input_l = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
4745   auto input_r = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));
4746   TF_RETURN_IF_ERROR(BroadcastTensors(input_l, input_r,
4747                                       /*check_feasibility=*/false, params));
4748 
4749   if (params->validation_only) return Status::OK();
4750 
4751   return ConvertMatMulHelper(params, *input_l, *input_r, transpose_a,
4752                              transpose_b);
4753 }
4754 
4755 Status ConvertSoftmax(OpConverterParams* params) {
4756   const auto& inputs = params->inputs;
4757   const auto& node_def = params->node_def;
4758   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
4759   TF_RETURN_IF_ERROR(
4760       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4761   ITensorProxyPtr tensor = inputs.at(0).tensor();
4762 
4763   const int num_trt_dims = tensor->getDimensions().nbDims;
4764   if (num_trt_dims == 0 && params->use_implicit_batch) {
4765     return errors::InvalidArgument(
4766         "TensorRT Softmax cannot apply on batch dimension");
4767   }
4768   if (params->validation_only) return Status::OK();
4769 
4770   nvinfer1::ISoftMaxLayer* layer =
4771       params->converter->network()->addSoftMax(*tensor->trt_tensor());
4772   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4773   params->converter->SetLayerName(layer, node_def);
4774   // Tensorflow SoftMax assumes applying softmax on the last dimension.
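  // setAxes takes a bitmask of dimensions; setting bit (num_trt_dims - 1)
  // selects the innermost dimension of the TRT tensor, matching TF's
  // last-axis softmax.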
4775   layer->setAxes(1 << (num_trt_dims - 1));
4776 
4777   ITensorProxyPtr output_tensor = layer->getOutput(0);
4778   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4779   return Status::OK();
4780 }
4781 
4782 Status ConvertArgMinMax(OpConverterParams* params) {
4783   const auto& inputs = params->inputs;
4784   const auto& node_def = params->node_def;
4785   TF_RETURN_IF_ERROR(
4786       CheckInputsWeights(*params, {{"input", false}, {"dimension", true}}));
4787   TF_RETURN_IF_ERROR(
4788       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4789 
4790   DataType output_dtype{DataType::DT_INT32};
4791   TF_RETURN_IF_ERROR(
4792       GetNodeAttr(AttrSlice(node_def), "output_type", &output_dtype));
4793 
4794   if (output_dtype != DataType::DT_INT32) {
4795     return errors::Unimplemented("Output type ", DataTypeString(output_dtype),
4796                                  " is not supported");
4797   }
4798   int tf_axis = inputs.at(1).weights().GetSpan<int>()[0];
4799   int trt_axis;
4800   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
4801   TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
4802                                  params->use_implicit_batch, &trt_axis));
4803   nvinfer1::TopKOperation topk_op;
4804   if (node_def.op() == "ArgMin") {
4805     topk_op = nvinfer1::TopKOperation::kMIN;
4806   } else if (node_def.op() == "ArgMax") {
4807     topk_op = nvinfer1::TopKOperation::kMAX;
4808   } else {
4809     return errors::InvalidArgument("Unsupported ArgMin/Max operation");
4810   }
4811 
4812 #if !IS_TRT_VERSION_GE(7, 0, 0, 11)
4813   const nvinfer1::Dims trt_dims = params->inputs.at(0).GetTrtDims();
4814   if (trt_dims.nbDims >= 4) {
4815     string trt_dim_str = DebugString(trt_dims);
4816 
4817     return errors::Unimplemented(node_def.op(), " op is not able to support",
4818                                  " tensors with 4+ dimensions (excluding batch",
4819                                  " size). Received: ", trt_dim_str);
4820   }
4821 #endif
4822 
4823   if (params->validation_only) return Status::OK();
4824 
4825   // Use TopK with k = 1. Only indices output is needed (output 1).
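  // addTopK expects the reduction axes as a bitmask, so set only the bit that
  // corresponds to the converted axis.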
4826   const uint32_t reduce_axes = 1 << trt_axis;
4827   nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(
4828       *inputs.at(0).tensor()->trt_tensor(), topk_op, 1, reduce_axes);
4829   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4830   params->converter->SetLayerName(layer, node_def, "topk");
4831   ITensorProxyPtr output_indices_tensor = layer->getOutput(1);
4832 
4833   // Squeeze on axis.
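  // An entry of 0 in input_dims marks that dimension for removal by
  // SqueezeTensor below.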
4834   std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
4835   input_dims[trt_axis] = 0;
4836   ITensorProxyPtr output_tensor = nullptr;
4837   TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
4838       /*input=*/output_indices_tensor,
4839       /*input_dims=*/&input_dims,
4840       /*params=*/params,
4841       /*output=*/&output_tensor));
4842   params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
4843 
4844   return Status::OK();
4845 }
4846 
4847 Status ConvertTopK(OpConverterParams* params) {
4848   const auto& inputs = params->inputs;
4849   const auto& node_def = params->node_def;
4850   TF_RETURN_IF_ERROR(
4851       CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
4852   TF_RETURN_IF_ERROR(
4853       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
4854   bool sorted{false};
4855   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "sorted", &sorted));
4856 
4857   if (!sorted) {
4858     // TensorRT only supports sorted output. Although the TensorFlow API
4859     // doesn't specify the order of output elements when sorted=false,
4860     // it's safer not to convert, because the TensorRT output might
4861     // differ from TensorFlow's output, which could cause confusion.
4862     return errors::InvalidArgument("Only sorted=True is supported");
4863   }
4864 
4865   ITensorProxyPtr tensor = inputs.at(0).tensor();
4866   const int num_dims = tensor->getDimensions().nbDims;
4867   if (num_dims == 0) {
4868     return errors::InvalidArgument(
4869         "TensorRT TopK cannot apply on batch dimension");
4870   }
4871 
4872   TRT_ShapedWeights k_w = inputs.at(1).weights();
4873   if (k_w.count() != 1) {
4874     return errors::InvalidArgument("k value of TopK should be a scalar");
4875   }
4876   // Note that ITopKLayer always has sorted outputs, so we don't need to handle
4877   // the 'sorted' attribute of the node.
4878   if (params->validation_only) return Status::OK();
4879 
4880   const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX;
4881   const int k = *(k_w.GetPointer<int>());
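  // TF TopKV2 operates on the last dimension, so the reduce-axes bitmask
  // selects only the innermost TRT axis.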
4882   const uint32_t reduce_axes = 1 << (num_dims - 1);
4883   nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(
4884       *tensor->trt_tensor(), op, k, reduce_axes);
4885   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
4886   params->converter->SetLayerName(layer, node_def);
4887 
4888   ITensorProxyPtr output_value_tensor = layer->getOutput(0);
4889   ITensorProxyPtr output_indices_tensor = layer->getOutput(1);
4890   params->outputs->push_back(TRT_TensorOrWeights(output_value_tensor));
4891   params->outputs->push_back(TRT_TensorOrWeights(output_indices_tensor));
4892   return Status::OK();
4893 }
4894 
4895 StatusOr<std::pair<ITensorProxyPtr, ITensorProxyPtr>>
4896 CalcDepthSpaceDynamicShape(OpConverterParams* params, int block_size,
4897                            string data_format) {
4898   // The input shape is not fully known at build time, so instead we use a shape
4899   // layer and shape arithmetic to calculate the reshape dimensions.
4900   const auto& inputs = params->inputs;
4901   const auto& node_def = params->node_def;
4902 
4903   const int channels_axis = data_format == "NCHW" ? 1 : 3;
4904   const int h_axis = data_format == "NCHW" ? 2 : 1;
4905   const int w_axis = data_format == "NCHW" ? 3 : 2;
4906 
4907   // Get shapes.
4908   ITensorProxyPtr shape = params->converter->network()
4909                               ->addShape(*inputs.at(0).tensor()->trt_tensor())
4910                               ->getOutput(0);
4911   ITensorProxyPtr batch_size =
4912       params->converter->network()
4913           ->addSlice(*shape->trt_tensor(), {1, {0}}, {1, {1}}, {1, {1}})
4914           ->getOutput(0);
4915   ITensorProxyPtr num_channels =
4916       params->converter->network()
4917           ->addSlice(*shape->trt_tensor(), {1, {channels_axis}}, {1, {1}},
4918                      {1, {1}})
4919           ->getOutput(0);
4920   ITensorProxyPtr h =
4921       params->converter->network()
4922           ->addSlice(*shape->trt_tensor(), {1, {h_axis}}, {1, {1}}, {1, {1}})
4923           ->getOutput(0);
4924   ITensorProxyPtr w =
4925       params->converter->network()
4926           ->addSlice(*shape->trt_tensor(), {1, {w_axis}}, {1, {1}}, {1, {1}})
4927           ->getOutput(0);
4928   ITensorProxyPtr r;
4929   TF_RETURN_IF_ERROR(CreateScalarConstant(params, block_size, &r));
4930   ITensorProxyPtr r_squared;
4931   TF_RETURN_IF_ERROR(
4932       CreateScalarConstant(params, block_size * block_size, &r_squared));
4933   // Get shuffle parameters.
4934   std::vector<ITensorProxyPtr> first_shuffle_tensors(6, nullptr);
4935   std::vector<ITensorProxyPtr> second_shuffle_tensors(4, nullptr);
4936   if (node_def.op() == "DepthToSpace") {
4937     // First Reshape [N, C, H, W] - > [N, r, r, C/(r*r), H, W].
4938     first_shuffle_tensors[0] = batch_size;
4939     first_shuffle_tensors[1] = r;
4940     first_shuffle_tensors[2] = r;
4941     first_shuffle_tensors[3] =
4942         params->converter->network()
4943             ->addElementWise(*num_channels->trt_tensor(),
4944                              *r_squared->trt_tensor(),
4945                              nvinfer1::ElementWiseOperation::kDIV)
4946             ->getOutput(0);
4947     first_shuffle_tensors[4] = h;
4948     first_shuffle_tensors[5] = w;
4949     // Second Reshape [N, C/(r*r), H, r, W, r] -> [N, C/(r*r), H * r, W * r].
4950     second_shuffle_tensors[0] = batch_size;
4951     second_shuffle_tensors[1] =
4952         params->converter->network()
4953             ->addElementWise(*num_channels->trt_tensor(),
4954                              *r_squared->trt_tensor(),
4955                              nvinfer1::ElementWiseOperation::kDIV)
4956             ->getOutput(0);
4957     second_shuffle_tensors[2] =
4958         params->converter->network()
4959             ->addElementWise(*h->trt_tensor(), *r->trt_tensor(),
4960                              nvinfer1::ElementWiseOperation::kPROD)
4961             ->getOutput(0);
4962     second_shuffle_tensors[3] =
4963         params->converter->network()
4964             ->addElementWise(*w->trt_tensor(), *r->trt_tensor(),
4965                              nvinfer1::ElementWiseOperation::kPROD)
4966             ->getOutput(0);
4967   } else if (node_def.op() == "SpaceToDepth") {
4968     // First Reshape [N, C, H, W] -> [N, C, H/r, r, W/r, r].
4969     first_shuffle_tensors[0] = batch_size;
4970     first_shuffle_tensors[1] = num_channels;
4971     first_shuffle_tensors[2] =
4972         params->converter->network()
4973             ->addElementWise(*h->trt_tensor(), *r->trt_tensor(),
4974                              nvinfer1::ElementWiseOperation::kDIV)
4975             ->getOutput(0);
4976     first_shuffle_tensors[3] = r;
4977     first_shuffle_tensors[4] =
4978         params->converter->network()
4979             ->addElementWise(*w->trt_tensor(), *r->trt_tensor(),
4980                              nvinfer1::ElementWiseOperation::kDIV)
4981             ->getOutput(0);
4982     first_shuffle_tensors[5] = r;
4983 
4984     // Second Reshape  [N, r, r, C, H/r, W/r] -> [N, C*r*r, H/r, W/r].
4985     second_shuffle_tensors[0] = batch_size;
4986     second_shuffle_tensors[1] =
4987         params->converter->network()
4988             ->addElementWise(*num_channels->trt_tensor(),
4989                              *r_squared->trt_tensor(),
4990                              nvinfer1::ElementWiseOperation::kPROD)
4991             ->getOutput(0);
4992     second_shuffle_tensors[2] =
4993         params->converter->network()
4994             ->addElementWise(*h->trt_tensor(), *r->trt_tensor(),
4995                              nvinfer1::ElementWiseOperation::kDIV)
4996             ->getOutput(0);
4997     second_shuffle_tensors[3] =
4998         params->converter->network()
4999             ->addElementWise(*w->trt_tensor(), *r->trt_tensor(),
5000                              nvinfer1::ElementWiseOperation::kDIV)
5001             ->getOutput(0);
5002   }
5003 
5004   StatusOr<ITensorProxyPtr> result =
5005       ConcatenateTensors(params, first_shuffle_tensors, 0);
5006   TF_RETURN_IF_ERROR(result.status());
5007   ITensorProxyPtr first_shuffle_shape = result.ValueOrDie();
5008 
5009   result = ConcatenateTensors(params, second_shuffle_tensors, 1);
5010   TF_RETURN_IF_ERROR(result.status());
5011   ITensorProxyPtr second_shuffle_shape = result.ValueOrDie();
5012 
5013   return std::make_pair(first_shuffle_shape, second_shuffle_shape);
5014 }
5015 
5016 Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
5017   const auto& inputs = params->inputs;
5018   const auto& node_def = params->node_def;
5019   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
5020   TF_RETURN_IF_ERROR(AllowDataTypes(
5021       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
5022 
5023   string data_format;
5024   int block_size;
5025   AttrSlice attrs(node_def);
5026   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
5027   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "block_size", &block_size));
5028 
5029   if (block_size < 2) {
5030     return errors::InvalidArgument("Block size must be 2 or greater");
5031   }
5032 
5033   if (data_format != "NCHW" && data_format != "NHWC") {
5034     return errors::Unimplemented("Data format ", data_format,
5035                                  " is not supported");
5036   }
5037   int idx_offset = params->use_implicit_batch ? 0 : 1;
5038   nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
5039   const int required_rank = 3 + idx_offset;
5040   if (dims.nbDims != required_rank) {
5041     return errors::InvalidArgument("The input to ", node_def.op(),
5042                                    " must be rank 4");
5043   }
5044   const int num_channels =
5045       data_format == "NCHW" ? dims.d[0 + idx_offset] : dims.d[2 + idx_offset];
5046   const int h =
5047       data_format == "NCHW" ? dims.d[1 + idx_offset] : dims.d[0 + idx_offset];
5048   const int w =
5049       data_format == "NCHW" ? dims.d[2 + idx_offset] : dims.d[1 + idx_offset];
5050   // Get shuffle parameters.
5051   nvinfer1::Dims first_shuffle_shape;
5052   nvinfer1::Permutation transpose_perm;
5053   nvinfer1::Dims second_shuffle_shape;
5054 
5055   // We define all the shuffle and transpose dimensions assuming implicit batch
5056   // mode. Afterwards we will update them to explicit batch mode if needed.
5057   // Additionally, an NCHW layout is assumed, and this assumption is corrected
5058   // afterwards with an initial transpose op. TODO(tfeher): Get rid of the
5059   // layout_transpose ops by defining shuffle shape specifically for NCHW and
5060   // NHWC.
5061   if (node_def.op() == "DepthToSpace") {
5062     if (num_channels != -1 && num_channels % (block_size * block_size) != 0) {
5063       return errors::InvalidArgument(
5064           "Number of channels must be divisible by block_size*block_size");
5065     }
5066     // First Reshape [C, H, W] - > [r, r, C/(r*r), H, W]
5067     first_shuffle_shape = {
5068         /*nbDims=*/5,
5069         /*d=*/{block_size, block_size, num_channels / (block_size * block_size),
5070                h, w}};
5071     // Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
5072     transpose_perm = {2, 3, 0, 4, 1};
5073     // Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
5074     second_shuffle_shape =
5075         nvinfer1::Dims3(num_channels / (block_size * block_size),
5076                         h * block_size, w * block_size);
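    // Worked example with C=12, H=2, W=3, r=2:
    // [12,2,3] -> [2,2,3,2,3] -> transpose {2,3,0,4,1} -> [3,2,2,3,2] -> [3,4,6].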
5077   } else {
5078     if (node_def.op() != "SpaceToDepth")
5079       return errors::InvalidArgument("Incorrect op type ", node_def.op());
5080     if ((h != -1 && h % block_size != 0) || (w != -1 && w % block_size != 0)) {
5081       return errors::InvalidArgument(
5082           "Width and height must be divisible by block_size");
5083     }
5084     // First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
5085     first_shuffle_shape = {/*nbDims=*/5,
5086                            /*d=*/{num_channels, h / block_size, block_size,
5087                                   w / block_size, block_size}};
5088     // Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
5089     transpose_perm = {2, 4, 0, 1, 3};
5090     // Second Reshape  [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
5091     second_shuffle_shape = nvinfer1::Dims3(
5092         num_channels * block_size * block_size, h / block_size, w / block_size);
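    // Worked example with C=3, H=4, W=6, r=2:
    // [3,4,6] -> [3,2,2,3,2] -> transpose {2,4,0,1,3} -> [2,2,3,2,3] -> [12,2,3].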
5093   }
5094   if (params->validation_only) return Status::OK();
5095 
5096   nvinfer1::IShuffleLayer* first_shuffle =
5097       params->converter->network()->addShuffle(
5098           *inputs.at(0).tensor()->trt_tensor());
5099   TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
5100   params->converter->SetLayerName(first_shuffle, node_def, "shuffle",
5101                                   /*op_instance=*/0);
5102 
5103   ITensorProxyPtr second_shuffle_shape_tensor;
5104 
5105   if (HasStaticShape(inputs.at(0).GetTrtDims())) {
5106     // Adjust a reshape constructed assuming implicit batch mode for explicit batch
5107     // mode. In particular, we need to insert the batch dimension size to the
5108     // beginning of all the dimension sizes. Example: reshape {20,10,30} for
5109     // implicit batch mode becomes reshape {N,20,10,30} for explicit batch mode.
5110     auto adjust_reshape = [](int N, nvinfer1::Dims dims,
5111                              bool use_implicit_batch) {
5112       if (use_implicit_batch) return dims;
5113       for (int i = dims.nbDims; i > 0; i--) {
5114         dims.d[i] = dims.d[i - 1];
5115       }
5116       dims.d[0] = N;
5117       dims.nbDims++;
5118       return dims;
5119     };
5120 
5121     first_shuffle_shape = adjust_reshape(dims.d[0], first_shuffle_shape,
5122                                          params->use_implicit_batch);
5123     second_shuffle_shape = adjust_reshape(dims.d[0], second_shuffle_shape,
5124                                           params->use_implicit_batch);
5125 
5126     first_shuffle->setReshapeDimensions(first_shuffle_shape);
5127   } else {
5128     StatusOr<std::pair<ITensorProxyPtr, ITensorProxyPtr>> result =
5129         CalcDepthSpaceDynamicShape(params, block_size, data_format);
5130     TF_RETURN_IF_ERROR(result.status());
5131     first_shuffle->setInput(1, *result.ValueOrDie().first->trt_tensor());
5132     second_shuffle_shape_tensor = result.ValueOrDie().second;
5133   }
5134 
5135   // Adjust a transpose constructed assuming implicit batch mode for explicit
5136   // batch mode. In particular, we need to add the batch dimension to d0 and
5137   // add 1 to all the dimension ids in the transpose. Example: permutation
5138   // {2,1,0} for implicit batch mode becomes permutation {0,3,2,1} for explicit
5139   // batch mode.
5140   auto adjust_perm = [](int n, nvinfer1::Permutation perm,
5141                         bool use_implicit_batch) {
5142     if (use_implicit_batch) return perm;
5143     for (int i = n; i > 0; i--) {
5144       perm.order[i] = perm.order[i - 1] + 1;
5145     }
5146     perm.order[0] = 0;
5147     return perm;
5148   };
5149   transpose_perm = adjust_perm(5, transpose_perm, params->use_implicit_batch);
5150 
5151   if (data_format == "NHWC") {
5152     nvinfer1::Permutation layout_transpose =
5153         adjust_perm(3, {2, 0, 1}, params->use_implicit_batch);
5154     first_shuffle->setFirstTranspose(layout_transpose);
5155   }
5156   first_shuffle->setSecondTranspose(transpose_perm);
5157 
5158   nvinfer1::IShuffleLayer* second_shuffle =
5159       params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
5160   TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
5161   params->converter->SetLayerName(second_shuffle, node_def, "shuffle",
5162                                   /*op_instance=*/1);
5163 
5164   if (HasStaticShape(inputs.at(0).GetTrtDims())) {
5165     second_shuffle->setReshapeDimensions(second_shuffle_shape);
5166   } else {
5167     second_shuffle->setInput(1, *second_shuffle_shape_tensor->trt_tensor());
5168   }
5169   if (data_format == "NHWC") {
5170     nvinfer1::Permutation layout_transpose =
5171         adjust_perm(3, {1, 2, 0}, params->use_implicit_batch);
5172     second_shuffle->setSecondTranspose(layout_transpose);
5173   }
5174 
5175   params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
5176   return Status::OK();
5177 }
5178 
5179 Status ConvertSquaredDifference(OpConverterParams* params) {
5180   TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", false}}));
5181   TF_RETURN_IF_ERROR(
5182       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5183   const auto& inputs = params->inputs;
5184   const auto& node_def = params->node_def;
5185   // Broadcast inputs.
5186   nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
5187   TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
5188       inputs.at(0), inputs.at(1), /*check_feasibility=*/true,
5189       params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
5190   ITensorProxyPtr tensor_l = nullptr;
5191   ITensorProxyPtr tensor_r = nullptr;
5192   TF_RETURN_IF_ERROR(
5193       PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l,
5194                             params->validation_only, &tensor_l, node_def));
5195   TF_RETURN_IF_ERROR(
5196       PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r,
5197                             params->validation_only, &tensor_r, node_def));
5198   if (params->validation_only) return Status::OK();
5199 
5200   // Subtract x - y.
5201   nvinfer1::IElementWiseLayer* sub =
5202       params->converter->network()->addElementWise(
5203           *tensor_l->trt_tensor(), *tensor_r->trt_tensor(),
5204           nvinfer1::ElementWiseOperation::kSUB);
5205   TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name());
5206   params->converter->SetLayerName(sub, node_def, "sub");
5207 
5208   // Multiply (x - y) * (x - y).
5209   nvinfer1::IElementWiseLayer* mul =
5210       params->converter->network()->addElementWise(
5211           *sub->getOutput(0), *sub->getOutput(0),
5212           nvinfer1::ElementWiseOperation::kPROD);
5213   TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name());
5214   params->converter->SetLayerName(mul, node_def, "mul");
5215 
5216   params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0)));
5217   return Status::OK();
5218 }
5219 
5220 #if IS_TRT_VERSION_GE(8, 2, 1, 6) || defined(TF_TRT_USE_EFFICIENT_NMS_PLUGIN)
5221 
5222 Status ConvertCombinedNMS(OpConverterParams* params) {
5223   TF_RETURN_IF_ERROR(CheckInputsWeights(
5224       *params, {{"boxes", TrtInputArg::kTensor},
5225                 {"scores", TrtInputArg::kTensor},
5226                 {"max_output_size_per_class", TrtInputArg::kWeight},
5227                 {"max_total_size", TrtInputArg::kWeight},
5228                 {"iou_threshold", TrtInputArg::kWeight},
5229                 {"score_threshold", TrtInputArg::kWeight}}));
5230   const auto& inputs = params->inputs;
5231   const auto& node_def = params->node_def;
5232   ITensorProxyPtr boxes_tensor = inputs.at(0).tensor();
5233   ITensorProxyPtr scores_tensor = inputs.at(1).tensor();
5234   if (params->use_implicit_batch) {
5235     return errors::Unimplemented(
5236         "Implicit batch mode not supported with CombinedNMS ", node_def.name());
5237   }
5238 
5239   TRT_ShapedWeights output_size_per_class = inputs.at(2).weights();
5240   TRT_ShapedWeights total_size = inputs.at(3).weights();
5241   TRT_ShapedWeights iou_threshold = inputs.at(4).weights();
5242   TRT_ShapedWeights score_threshold = inputs.at(5).weights();
5243   const int max_size_per_class = *(output_size_per_class.GetPointer<int>());
5244   int max_total_size = *(total_size.GetPointer<int>());
5245   const float iou_thresh = *(iou_threshold.GetPointer<float>());
5246   const float score_thresh = *(score_threshold.GetPointer<float>());
5247 
5248   AttrSlice attrs(node_def);
5249   bool clip_boxes = false, pad_per_class = false;
5250   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "clip_boxes", &clip_boxes));
5251   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "pad_per_class", &pad_per_class));
5252 
5253   // Validate tensors and weights
5254   const auto boxes_dims = boxes_tensor->getDimensions();
5255   const auto scores_dims = scores_tensor->getDimensions();
5256   if (boxes_dims.nbDims != 4) {
5257     return errors::InvalidArgument(
5258         "NMS TRT Plugin input boxes must be 4-D including batch ",
5259         node_def.name());
5260   }
5261   const int num_classes = scores_dims.d[2];
5262   bool box_check = boxes_dims.d[2] == 1 || boxes_dims.d[2] == num_classes;
5263   if (!box_check) {
5264     return errors::InvalidArgument(
5265         "NMS TRT Plugin third dimension of boxes must be either 1 "
5266         "or match the num_classes dimension of scores ",
5267         node_def.name());
5268   }
5269 
5270   if (output_size_per_class.count() != 1) {
5271     return errors::InvalidArgument(
5272         "NMS TRT Plugin max_output_size_per_class must be scalar ",
5273         node_def.name());
5274   }
5275   if (max_size_per_class <= 0) {
5276     return errors::InvalidArgument(
5277         "NMS TRT Plugin max_output_size_per_class should be > 0",
5278         node_def.name());
5279   }
5280   if (total_size.count() != 1) {
5281     return errors::InvalidArgument(
5282         "NMS TRT Plugin max_total_size must be scalar ", node_def.name());
5283   }
5284   if (max_total_size <= 0) {
5285     return errors::InvalidArgument(
5286         "NMS TRT Plugin max_total_size should be > 0", node_def.name());
5287   }
5288   if (iou_threshold.count() != 1) {
5289     return errors::InvalidArgument(
5290         "NMS TRT Plugin iou_threshold must be scalar ", node_def.name());
5291   }
5292   if (iou_thresh < 0.0 || iou_thresh > 1.0) {
5293     return errors::InvalidArgument(
5294         "NMS TRT Plugin iou_threshold must be in [0, 1]", node_def.name());
5295   }
5296   if (score_threshold.count() != 1) {
5297     return errors::InvalidArgument(
5298         "NMS TRT Plugin score_threshold must be scalar ", node_def.name());
5299   }
5300 
5301   if (params->validation_only) return Status::OK();
5302 
5303   // Create plugin
5304   nvinfer1::PluginField fields[6] = {
5305       nvinfer1::PluginField{"max_output_size_per_class", &max_size_per_class,
5306                             nvinfer1::PluginFieldType::kINT32, 1},
5307       nvinfer1::PluginField{"max_total_size", &max_total_size,
5308                             nvinfer1::PluginFieldType::kINT32, 1},
5309       nvinfer1::PluginField{"iou_threshold", &iou_thresh,
5310                             nvinfer1::PluginFieldType::kFLOAT32, 1},
5311       nvinfer1::PluginField{"score_threshold", &score_thresh,
5312                             nvinfer1::PluginFieldType::kFLOAT32, 1},
5313       nvinfer1::PluginField{"pad_per_class", &pad_per_class,
5314                             nvinfer1::PluginFieldType::kINT32, 1},
5315       nvinfer1::PluginField{"clip_boxes", &clip_boxes,
5316                             nvinfer1::PluginFieldType::kINT32, 1},
5317   };
5318   nvinfer1::PluginFieldCollection fc{6, fields};
5319 
5320   auto creator =
5321       getPluginRegistry()->getPluginCreator("EfficientNMS_TFTRT_TRT", "1", "");
5322   TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
5323 
5324   TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
5325       creator->createPlugin(node_def.name().c_str(), &fc));
5326   TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
5327 
5328   // Set plugin inputs
5329   std::vector<nvinfer1::ITensor*> trt_plugin_inputs;
5330   trt_plugin_inputs.push_back(boxes_tensor->trt_tensor());
5331   trt_plugin_inputs.push_back(scores_tensor->trt_tensor());
5332 
5333   // Add plugin to network
5334   nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2(
5335       &trt_plugin_inputs[0], static_cast<int>(trt_plugin_inputs.size()),
5336       *plugin);
5337   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5338   params->converter->SetLayerName(layer, node_def, "plugin");
5339 
5340   // Set plugin outputs
5341   ITensorProxyPtr output_num_detections = layer->getOutput(0);
5342   ITensorProxyPtr output_detection_boxes = layer->getOutput(1);
5343   ITensorProxyPtr output_detection_scores = layer->getOutput(2);
5344   ITensorProxyPtr output_detection_classes = layer->getOutput(3);
5345 
5346   // Cast the classes output from int32 to float32
5347   nvinfer1::IIdentityLayer* layer_detection_classes =
5348       params->converter->network()->addIdentity(
5349           *output_detection_classes->trt_tensor());
5350   layer_detection_classes->setOutputType(0, nvinfer1::DataType::kFLOAT);
5351   output_detection_classes = layer_detection_classes->getOutput(0);
5352 
5353   // The plugin produces a [N, 1] tensor for the num output, squeeze it to [N]
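  // The trailing 0 in input_dims marks the dimension to be removed by
  // SqueezeTensor.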
5354   std::vector<int> input_dims{output_num_detections->getDimensions().d[0], 0};
5355   TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
5356       /*input=*/output_num_detections,
5357       /*input_dims=*/&input_dims,
5358       /*params=*/params,
5359       /*output=*/&output_num_detections));
5360 
5361   // Final outputs
5362   params->outputs->push_back(TRT_TensorOrWeights(output_detection_boxes));
5363   params->outputs->push_back(TRT_TensorOrWeights(output_detection_scores));
5364   params->outputs->push_back(TRT_TensorOrWeights(output_detection_classes));
5365   params->outputs->push_back(TRT_TensorOrWeights(output_num_detections));
5366 
5367   return Status::OK();
5368 }
5369 
5370 #elif IS_TRT_VERSION_GE(7, 1, 3, 0)
5371 
5372 bool AllowNmsTopkOverride() {
5373   static bool result = [] {
5374     bool value;
5375     Status status = ReadBoolFromEnvVar("TF_TRT_ALLOW_NMS_TOPK_OVERRIDE",
5376                                        /*default_value=*/false, &value);
5377     if (!status.ok()) {
5378       LOG(ERROR) << status;
5379     }
5380     return value;
5381   }();
5382   return result;
5383 }
5384 
5385 Status ConvertCombinedNMS(OpConverterParams* params) {
5386   TF_RETURN_IF_ERROR(
5387       CheckInputsWeights(*params, {{"boxes", false},
5388                                    {"scores", false},
5389                                    {"max_output_size_per_class", true},
5390                                    {"max_total_size", true},
5391                                    {"iou_threshold", true},
5392                                    {"score_threshold", true}}));
5393   const auto& inputs = params->inputs;
5394   const auto& node_def = params->node_def;
5395 
5396   ITensorProxyPtr boxes_tensor = inputs.at(0).tensor();
5397   ITensorProxyPtr scores_tensor = inputs.at(1).tensor();
5398   TRT_ShapedWeights output_size_per_class = inputs.at(2).weights();
5399   TRT_ShapedWeights total_size = inputs.at(3).weights();
5400   TRT_ShapedWeights iou_threshold = inputs.at(4).weights();
5401   TRT_ShapedWeights score_threshold = inputs.at(5).weights();
5402 
5403   // Validate tensors and weights (also set some of the needed plugin fields)
5404   const auto boxes_dims = boxes_tensor->getDimensions();
5405   const auto scores_dims = scores_tensor->getDimensions();
5406   if (!params->use_implicit_batch &&
5407       (!HasStaticShape(boxes_dims) || !HasStaticShape(scores_dims))) {
5408     return errors::Unimplemented(
5409         "TensorRT BatchedNMS Plugin requires input with static shape");
5410   }
5411   const int offset = params->use_implicit_batch ? 0 : 1;
5412   if (boxes_dims.nbDims != 3 + offset) {
5413     return errors::InvalidArgument(
5414         "TensorRT BatchedNMS Plugin input boxes must be 4-D including batch");
5415   }
5416   const int class_idx = 1 + offset;
5417   const int num_classes = scores_dims.d[class_idx];
5418   const int num_boxes = boxes_dims.d[0 + offset];
5419   bool box_check =
5420       boxes_dims.d[class_idx] == 1 || boxes_dims.d[class_idx] == num_classes;
5421   if (!box_check) {
5422     return errors::InvalidArgument(
5423         "TensorRT BatchedNMS Plugin third dimension of boxes must be either 1 "
5424         "or num_classes");
5425   }
5426 
5427   if (output_size_per_class.count() != 1) {
5428     return errors::InvalidArgument(
5429         "TensorRT BatchedNMS Plugin max_output_size_per_class must be scalar");
5430   }
5431   int max_size_per_class = *(output_size_per_class.GetPointer<int>());
5432   if (max_size_per_class <= 0) {
5433     return errors::InvalidArgument(
5434         "TensorRT BatchedNMS Plugin max_output_size_per_class should be > 0");
5435   }
5436   if (total_size.count() != 1) {
5437     return errors::InvalidArgument(
5438         "TensorRT BatchedNMS Plugin max_total_size must be scalar");
5439   }
5440   int max_total_size = *(total_size.GetPointer<int>());
5441   if (max_total_size <= 0) {
5442     return errors::InvalidArgument(
5443         "TensorRT BatchedNMS Plugin max_total_size should be > 0");
5444   }
5445   if (iou_threshold.count() != 1) {
5446     return errors::InvalidArgument(
5447         "TensorRT BatchedNMS Plugin iou_threshold must be scalar");
5448   }
5449   float iou_thresh = *(iou_threshold.GetPointer<float>());
5450   if (iou_thresh < 0.0 || iou_thresh > 1.0) {
5451     return errors::InvalidArgument(
5452         "TensorRT BatchedNMS Plugin iou_threshold must be in [0, 1]");
5453   }
5454   if (score_threshold.count() != 1) {
5455     return errors::InvalidArgument(
5456         "TensorRT BatchedNMS Plugin score_threshold must be scalar");
5457   }
5458 
5459   bool pad_per_class = false, clip_boxes = false;
5460   AttrSlice attrs(node_def);
5461   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "pad_per_class", &pad_per_class));
5462   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "clip_boxes", &clip_boxes));
5463 
5464   // TRT op is_normalized=False treats input coordinates as pixels and
5465   // calculates width/height as (max - min + 1).
5466   //
5467   // TF op CombinedNonMaxSuppression doesn't care about the normalization and
5468   // calculates width/height as (max - min).
5469   //
5470   // We set is_normalized = true to be consistent with the TF IOU calculation.
5471   const bool is_normalized = true;
5472 
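  // shareLocation is true when the boxes tensor has size 1 in its class
  // dimension, i.e. a single box set is shared across all classes.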
5473   bool share_location = (boxes_dims.d[class_idx] == 1);
5474   int keep_top_k = 0;
5475   if (pad_per_class) {
5476     keep_top_k = std::min(max_size_per_class * num_classes, max_total_size);
5477   } else {
5478     keep_top_k = max_total_size;
5479   }
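  // For example, with num_classes=90, max_size_per_class=100 and
  // max_total_size=200, pad_per_class=true gives keep_top_k = min(9000, 200) = 200.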
5480 
5481   // According to the batchedNMS plugin description we need to set top_k so that
5482   // keep_top_k <= top_k
5483   // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
5484   // Before the NMS step, TRT selects the top_k candidates from each class and
5485   // discards the rest. The NMS step is performed only among the top_k
5486   // candidates. To be strictly compatible with the TF op, we need top_k to be
5487   // greater than or equal to num_boxes.
5488   int top_k = std::max(num_boxes, keep_top_k);
5489   // TRT has a limitation: top_k <=4096.
5490   if (top_k > 4096) {
5491     if (AllowNmsTopkOverride()) {
5492       top_k = 4096;
5493       keep_top_k = std::min(top_k, keep_top_k);
5494     } else {
5495       return errors::InvalidArgument(
5496           "TRT NMS plugin allows top_k<=4096, where top_k = max(num_boxes, "
5497           "max_total_size). You can override this by setting "
5498           "TF_TRT_ALLOW_NMS_TOPK_OVERRIDE=1 environment variable, but this can "
5499           "result in a loss of accuracy.");
5500     }
5501   }
5502 
5503   if (params->validation_only) return Status::OK();
5504   float score_thresh = *(score_threshold.GetPointer<float>());
5505   const int background_id = -1;
5506   nvinfer1::PluginField fields[9] = {
5507       nvinfer1::PluginField{"shareLocation", &share_location,
5508                             nvinfer1::PluginFieldType::kINT32, 1},
5509       nvinfer1::PluginField{"backgroundLabelId", &background_id,
5510                             nvinfer1::PluginFieldType::kINT32, 1},
5511       nvinfer1::PluginField{"numClasses", &num_classes,
5512                             nvinfer1::PluginFieldType::kINT32, 1},
5513       nvinfer1::PluginField{"topK", &top_k, nvinfer1::PluginFieldType::kINT32,
5514                             1},
5515       nvinfer1::PluginField{"keepTopK", &keep_top_k,
5516                             nvinfer1::PluginFieldType::kINT32, 1},
5517       nvinfer1::PluginField{"scoreThreshold", &score_thresh,
5518                             nvinfer1::PluginFieldType::kFLOAT32, 1},
5519       nvinfer1::PluginField{"iouThreshold", &iou_thresh,
5520                             nvinfer1::PluginFieldType::kFLOAT32, 1},
5521       nvinfer1::PluginField{"isNormalized", &is_normalized,
5522                             nvinfer1::PluginFieldType::kINT32, 1},
5523       nvinfer1::PluginField{"clipBoxes", &clip_boxes,
5524                             nvinfer1::PluginFieldType::kINT32, 1}};
5525   nvinfer1::PluginFieldCollection fc{9, fields};
5526 
5527   // Get plugin creator
5528   auto creator =
5529       getPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1", "");
5530   TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name());
5531 
5532   // Create plugin
5533   TrtUniquePtrType<nvinfer1::IPluginV2> plugin(
5534       creator->createPlugin(node_def.name().c_str(), &fc));
5535   TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name());
5536 
5537   // Set plugin inputs
5538   std::vector<nvinfer1::ITensor*> trt_plugin_inputs;
5539   trt_plugin_inputs.push_back(boxes_tensor->trt_tensor());
5540   trt_plugin_inputs.push_back(scores_tensor->trt_tensor());
5541 
5542   // Add plugin to network
5543   nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2(
5544       &trt_plugin_inputs[0], static_cast<int>(trt_plugin_inputs.size()),
5545       *plugin);
5546   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5547   params->converter->SetLayerName(layer, node_def, "plugin");
5548 
5549   // Set plugin outputs
5550   ITensorProxyPtr output_nmsed_boxes = layer->getOutput(1);
5551 
5552   // TensorRT fixes (removes) the extra last dimension in CombinedNMS outputs
5553   ITensorProxyPtr output_num_detections = layer->getOutput(0);
5554   ITensorProxyPtr output_nmsed_scores = layer->getOutput(2);
5555   ITensorProxyPtr output_nmsed_classes = layer->getOutput(3);
5556 
5557   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes));
5558   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_scores));
5559   params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_classes));
5560   params->outputs->push_back(TRT_TensorOrWeights(output_num_detections));
5561 
5562   return Status::OK();
5563 }
5564 
5565 #endif  // IS_TRT_VERSION_GE(7, 1, 3, 0)
5566 
5567 Status ConvertResize(OpConverterParams* params) {
5568   const auto& inputs = params->inputs;
5569   const auto& node_def = params->node_def;
5570   TF_RETURN_IF_ERROR(CheckInputsWeights(
5571       *params,
5572       {{"input", TrtInputArg::kTensor}, {"size", TrtInputArg::kBoth}}));
5573   TF_RETURN_IF_ERROR(AllowDataTypes(
5574       *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
5575 
5576   // Get input tensor.
5577   ITensorProxyPtr inputs_tensor = inputs.at(0).tensor();
5578   TFTRT_RETURN_ERROR_IF_NULLPTR(inputs_tensor, params->node_def.name());
5579 
5580   // Check output size. It must contain two values, i.e. [H_out, W_out].
5581   const bool const_output_size = inputs.at(1).is_weights();
5582   if (const_output_size) {
5583     // Output size is given as a constant.
5584     if (inputs.at(1).weights().count() != 2) {
5585       return errors::Unimplemented("Resize requires 2D values for the size");
5586     }
5587   } else {
5588     // Output size is given as a tensor, possibly as the result of shape
5589     // calculation ops in the graph.
5590     if (params->use_implicit_batch) {
5591       return errors::Unimplemented(
5592           "Resize requires constant size in implicit batch mode");
5593     }
5594     TF_RETURN_IF_ERROR(ExpectShapeTensor(inputs.at(1)));
5595     if (inputs.at(1).tensor()->getDimensions().d[0] != 2) {
5596       return errors::Unimplemented("Resize requires 2D values for the size");
5597     }
5598   }
5599 
5600   // Verify and consume node attributes.
5601   bool align_corners;
5602   TF_RETURN_IF_ERROR(
5603       GetNodeAttr(AttrSlice(node_def), "align_corners", &align_corners));
5604   TF_RETURN_IF_ERROR(
5605       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5606 
5607   // Verify resize mode. Initialize resize mode if supported.
5608   nvinfer1::ResizeMode resize_mode;
5609   if (node_def.op() == "ResizeBilinear") {
5610 #if IS_TRT_VERSION_GE(7, 1, 0, 0)
5611     if (!align_corners) {
5612       return errors::InvalidArgument(
5613           "Cannot Convert Bilinear Resize when align_corners=False");
5614     }
5615 #endif
5616     resize_mode = nvinfer1::ResizeMode::kLINEAR;
5617   } else if (node_def.op() == "ResizeNearestNeighbor") {
5618     resize_mode = nvinfer1::ResizeMode::kNEAREST;
5619   } else {
5620     return errors::Unimplemented(node_def.op(), " is not yet implemented");
5621   }
5622 
5623   // return after validation if only validation is requested.
5624   if (params->validation_only) return Status::OK();
5625 
5626   // Transpose tensor from NHWC to NCHW format.
5627   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
5628       inputs_tensor, {0, 3, 1, 2}, &inputs_tensor, node_def, "to_NCHW"));
5629 
5630   // Calculate the output shape as static dimensions or a shape tensor:
5631   // Given input shape [N, C, H, W] and output size [H_out, W_out],
5632   // output shape equals [N, C, H_out, W_out].
5633   nvinfer1::Dims output_shape_dims;
5634   ITensorProxyPtr output_shape_tensor;
5635   const bool static_output_shape =
5636       HasStaticShape(inputs_tensor->getDimensions()) && const_output_size;
5637   if (static_output_shape) {
5638     // If the output shape can be fully determined at build time, calculate it
5639     // as a set of dimensions.
5640     output_shape_dims.nbDims = inputs_tensor->getDimensions().nbDims;
5641     for (int i = 0; i < output_shape_dims.nbDims; ++i) {
5642       output_shape_dims.d[i] = inputs_tensor->getDimensions().d[i];
5643     }
5644     const int* weights_ptr = inputs.at(1).weights().GetPointer<int>();
5645     output_shape_dims.d[output_shape_dims.nbDims - 2] = weights_ptr[0];
5646     output_shape_dims.d[output_shape_dims.nbDims - 1] = weights_ptr[1];
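    // For example, an input of shape [8, 3, 32, 32] with size = [64, 64]
    // yields output dimensions [8, 3, 64, 64].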
5647   } else {
5648     // Otherwise, build the output shape as a shape tensor that will be computed
5649     // at run time.
5650     // The batch size and num of channels will be copied from the input shape.
5651     ITensorProxyPtr shape = params->converter->network()
5652                                 ->addShape(*inputs_tensor->trt_tensor())
5653                                 ->getOutput(0);
5654     ITensorProxyPtr batch_size =
5655         params->converter->network()
5656             ->addSlice(*shape->trt_tensor(), {1, {0}}, {1, {1}}, {1, {1}})
5657             ->getOutput(0);
5658     ITensorProxyPtr num_channels =
5659         params->converter->network()
5660             ->addSlice(*shape->trt_tensor(), {1, {1}}, {1, {1}}, {1, {1}})
5661             ->getOutput(0);
5662 
5663     // The height and width will be obtained from the requested output size.
5664     ITensorProxyPtr height, width;
5665     if (const_output_size) {
5666       // If the output size is constant, the height and width dimensions can be
5667       // created as constants from the size values.
5668       const int* weights_ptr = inputs.at(1).weights().GetPointer<int>();
5669       TF_RETURN_IF_ERROR(CreateScalarConstant(params, weights_ptr[0], &height));
5670       TF_RETURN_IF_ERROR(CreateScalarConstant(params, weights_ptr[1], &width));
5671     } else {
5672       // Otherwise, the size is a tensor which can be sliced, and each element
5673       // used directly as the output height and width dimensions.
5674       ITensorProxyPtr size = inputs.at(1).tensor();
5675       height = params->converter->network()
5676                    ->addSlice(*size->trt_tensor(), {1, {0}}, {1, {1}}, {1, {1}})
5677                    ->getOutput(0);
5678       width = params->converter->network()
5679                   ->addSlice(*size->trt_tensor(), {1, {1}}, {1, {1}}, {1, {1}})
5680                   ->getOutput(0);
5681     }
5682 
5683     StatusOr<ITensorProxyPtr> result = ConcatenateTensors(
5684         params, {batch_size, num_channels, height, width}, 0);
5685     TF_RETURN_IF_ERROR(result.status());
5686     output_shape_tensor = result.ValueOrDie();
5687   }
5688 
5689   // Add resize layer.
5690   nvinfer1::IResizeLayer* layer =
5691       params->converter->network()->addResize(*inputs_tensor->trt_tensor());
5692   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5693   params->converter->SetLayerName(layer, node_def);
5694 
5695   // Set layer parameters.
5696   layer->setResizeMode(resize_mode);
5697   layer->setAlignCorners(align_corners);
5698 
5699   // Set output shape.
5700   if (static_output_shape) {
5701     // If the shapes are fully known at build time, pass the static output shape
5702     // to the resize layer as expected output dimensions.
5703     layer->setOutputDimensions(output_shape_dims);
5704   } else {
5705     // Otherwise, pass the output shape tensor to the resize layer as an input.
5706     layer->setInput(1, *output_shape_tensor->trt_tensor());
5707   }
5708 
5709   // Get output tensor. Transpose it from NCHW to NHWC.
5710   ITensorProxyPtr output = layer->getOutput(0);
5711 
5712   TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
5713       output, {0, 2, 3, 1}, &output, node_def, "to_NHWC"));
5714   params->outputs->push_back(TRT_TensorOrWeights(output));
5715   // Success
5716   return Status::OK();
5717 }  // ConvertResize
5718 
5719 Status ConvertAddN(OpConverterParams* params) {
5720   const auto& inputs = params->inputs;
5721   const auto& node_def = params->node_def;
5722   TF_RETURN_IF_ERROR(
5723       AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
5724 
5725   int num_inputs;
5726   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node_def), "N", &num_inputs));
5727 
5728   if (num_inputs < 2) {
5729     return errors::InvalidArgument("AddN requires at least two inputs");
5730   }
5731 
5732   TFTRT_CHECK_INPUT_SIZE(inputs.size(), num_inputs, node_def);
5733 
5734   for (const auto& input : inputs) {
5735     if (!input.is_tensor() && input.weights().Shape().dim(0) != 1) {
5736       return errors::InvalidArgument(
5737           "Weights input to AddN is required to have batch dimension 1.");
5738     }
5739   }
5740   if (params->validation_only) return Status::OK();
5741 
5742   // AddN doesn't support broadcast.
5743   std::vector<ITensorProxyPtr> tensor_inputs;
5744   tensor_inputs.reserve(inputs.size());
5745   for (const auto& input : inputs) {
5746     if (input.is_tensor()) {
5747       tensor_inputs.push_back(input.tensor());
5748     } else {
5749       auto dims = input.weights().Shape();
5750       TF_RETURN_IF_ERROR(dims.RemoveBatchDimension());
5751       tensor_inputs.push_back(params->converter->CreateConstantLayer(
5752           input.weights(), dims.AsTrtDims()));
5753     }
5754   }
5755   ITensorProxyPtr lhs = tensor_inputs[0];
5756   for (int i = 1; i < num_inputs; ++i) {
5757     ITensorProxyPtr rhs = tensor_inputs[i];
5758     nvinfer1::ILayer* layer = params->converter->network()->addElementWise(
5759         *lhs->trt_tensor(), *rhs->trt_tensor(),
5760         nvinfer1::ElementWiseOperation::kSUM);
5761     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
5762     params->converter->SetLayerName(layer, node_def, std::to_string(i));
5763     lhs = layer->getOutput(0);
5764   }
5765   params->outputs->push_back(TRT_TensorOrWeights(lhs));
5766   return Status::OK();
5767 }
5768 
5769 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertBiasAdd, "BiasAdd");
5770 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertClipByValue, "ClipByValue");
5771 #if IS_TRT_VERSION_GE(7, 1, 3, 0) || defined(TF_TRT_USE_EFFICIENT_NMS_PLUGIN)
5772 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertCombinedNMS,
5773                                   "CombinedNonMaxSuppression");
5774 #endif
5775 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertAddN, "AddN");
5776 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertCast, "Cast");
5777 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConcat, "ConcatV2");
5778 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConst, "Const");
5779 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv2D, "Conv2D");
5780 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv2DBackpropInput,
5781                                   "Conv2DBackpropInput");
5782 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertDepthSpaceShuffle, "DepthToSpace");
5783 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv2DDepthwise,
5784                                   "DepthwiseConv2dNative");
5785 
5786 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertExpandDims, "ExpandDims");
5787 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertFusedConv2DBiasActivation,
5788                                   "FusedConv2DBiasActivation");
5789 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertGather, "GatherV2");
5790 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertMatMul, "MatMul");
5791 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPack, "Pack");
5792 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPad, "Pad");
5793 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertReshape, "Reshape");
5794 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv3D, "Conv3D");
5795 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertConv3DBackpropInputV2,
5796                                   "Conv3DBackpropInputV2");
5797 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertResize, "ResizeBilinear");
5798 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertResize, "ResizeNearestNeighbor");
5799 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPool3D, "AvgPool3D");
5800 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPool3D, "MaxPool3D");
5801 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertShape, "Shape");
5802 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSlice, "Slice");
5803 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSoftmax, "Softmax");
5804 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertDepthSpaceShuffle, "SpaceToDepth");
5805 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSplit, "Split");
5806 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSquare, "Square");
5807 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSquaredDifference,
5808                                   "SquaredDifference");
5809 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertSqueeze, "Squeeze");
5810 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertStridedSlice, "StridedSlice");
5811 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTopK, "TopKV2");
5812 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertTranspose, "Transpose");
5813 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertUnpack, "Unpack");
5814 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertPool, {"MaxPool", "AvgPool"});
5815 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertFusedBatchNorm,
5816                                   {"FusedBatchNorm", "FusedBatchNormV2",
5817                                    "FusedBatchNormV3"});
5818 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertReduce,
5819                                   {"Sum", "Prod", "Max", "Min", "Mean"});
5820 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertArgMinMax, {"ArgMin", "ArgMax"});
5821 // The following are no-ops during inference and will not be mapped to any
5822 // TRT layer.
5823 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertIdentity,
5824                                   {"Identity", "IdentityN", "Snapshot",
5825                                    "StopGradient", "_CopyFromHostToGpu"});
5826 REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertBatchMatMul,
5827                                   {"BatchMatMul", "BatchMatMulV2"});
5828 
5829 Status ConvertGraphDefToEngine(
5830     const GraphDef& gdef, OpKernelContext* ctx, TrtPrecisionMode precision_mode,
5831     int max_batch_size, size_t max_workspace_size_bytes,
5832     const std::vector<PartialTensorShape>& input_shapes,
5833     nvinfer1::ILogger* trt_logger, nvinfer1::IGpuAllocator* allocator,
5834     TRTInt8Calibrator* calibrator,
5835     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
5836     const bool use_implicit_batch, bool* convert_successfully,
5837     TrtShapeOptimizationProfile* profiles, absl::string_view engine_name,
5838     bool use_explicit_precision, tensorflow::grappler::Cluster* cluster) {
5839   engine->reset();
5840   if (convert_successfully) *convert_successfully = false;
5841 
5842   // Creating converter, TensorRT builder and network
5843   auto statusor = Converter::Create(precision_mode, use_calibration, trt_logger,
5844                                     use_implicit_batch, engine_name,
5845                                     use_explicit_precision, ctx);
5846 
5847   TF_RETURN_IF_ERROR(statusor.status());
5848   std::unique_ptr<Converter> converter = std::move(statusor.ValueOrDie());
5849 
5850   GraphDef graph = gdef;
5851   if (cluster != nullptr) {
5852     bool apply_layout_optim;
5853     Status status =
5854         ReadBoolFromEnvVar("TF_TRT_ENABLE_LAYOUT_OPTIMIZER",
5855                            /*default_value=*/true, &apply_layout_optim);
5856     if (!status.ok()) {
5857       LOG(ERROR) << status;
5858     }
5859     if (apply_layout_optim) {
5860       tensorflow::grappler::GrapplerItem grappler_item;
5861       grappler_item.graph = gdef;
5862       // TensorRT API requires the input for convolution to be in NCHW.
5863       tensorflow::grappler::GenericLayoutOptimizer layout_optimizer("NCHW");
5864       TF_RETURN_IF_ERROR(
5865           layout_optimizer.Optimize(cluster, grappler_item, &graph));
5866 
5867       grappler_item.graph = graph;
5868 
5869       tensorflow::grappler::ConstantFolding const_optimizer(
5870           nullptr,
5871           /*disable_compressed_tensor_optimization=*/false,
5872           /*fold_quantization_emulation=*/false);
5873       TF_RETURN_IF_ERROR(
5874           const_optimizer.Optimize(cluster, grappler_item, &graph));
5875 
5876       // The optimizers may break the topological order,
5877       // so we need these steps to restore it.
5878       Graph g(OpRegistry::Global());
5879       TF_RETURN_IF_ERROR(
5880           ConvertGraphDefToGraph(GraphConstructorOptions(), graph, &g));
5881       g.ToGraphDef(&graph);
5882     }
5883   }
5884   VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers";
5885   std::vector<Converter::EngineOutputInfo> output_tensors;
5886   int num_layers = converter->network()->getNbLayers();
5887   absl::flat_hash_set<const char*> layer_names;
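  // Each node below falls into one of three cases: engine inputs
  // (Placeholder / _Arg) are added as TRT network inputs or input resources,
  // engine outputs (Identity / _Retval) are recorded in output_tensors for
  // later renaming, and every other node is converted to TRT layers by its
  // registered op converter.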
5888   // Graph nodes are already topologically sorted during construction
5889   for (const auto& node_def : graph.node()) {
5890     const string& node_name = node_def.name();
5891     VLOG(2) << "Converting node " << node_name << ", op=" << node_def.op();
5892     if (IsEngineInput(node_name)) {
5893       int32 slot_number = -1;
5894       string type_key;
5895       if (node_def.op() == "Placeholder") {
5896         if (!strings::safe_strto32(  // non-absl ok
5897                 node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
5898                 &slot_number)) {
5899           return errors::InvalidArgument("Failed to parse slot number from ",
5900                                          node_name);
5901         }
5902         type_key = "dtype";
5903       } else if (tensorflow::grappler::IsArg(node_def)) {
5904         // Maybe remove the dependence on grappler and re-implement IsArg,
5905         // which is pretty simple (but could change if new Arg nodes are added)
5906         slot_number = node_def.attr().at("index").i();
5907         type_key = "T";
5908       } else {
5909         return errors::InvalidArgument(
5910             "Node ", node_name,
5911             " is neither Placeholder nor Arg, instead ", node_def.op());
5912       }
5913       DataType tf_dtype = node_def.attr().at(type_key).type();
5914       if (tf_dtype == DT_RESOURCE) {
5915         VLOG(2) << "Adding engine input resource " << node_name;
5916         TF_RETURN_IF_ERROR(converter->AddInputResource(
5917             node_name, ctx->input(slot_number).flat<ResourceHandle>()(0)));
5918       } else {
5919         nvinfer1::DataType trt_dtype;
5920         nvinfer1::Dims trt_dims;
5921         int batch_size = -1;
5922         const auto shape = input_shapes.at(slot_number);
5923         const auto status = ValidateTensorProperties(
5924             node_def.op(), node_def.attr().at(type_key).type(), shape,
5925             use_implicit_batch, /*validation_only=*/false, &trt_dtype,
5926             &trt_dims, &batch_size);
5927         if (!status.ok()) {
5928           const string error_message =
5929               StrCat("Validation failed for ", node_name, " and input slot ",
5930                      slot_number, ": ", status.error_message());
5931           LOG_WARNING_WITH_PREFIX << error_message;
5932           return errors::CreateWithUpdatedMessage(status, error_message);
5933         }
5934         VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
5935                 << DebugString(trt_dims);
5936         // TODO(laigd): the conversion should always happen at runtime where all
5937         // the shapes are known, and we can provide a mode to generate the
5938         // engines offline, by calling sess.run() and cache/serialize the
5939         // engines.
5940         TF_RETURN_IF_ERROR(converter->AddInputTensor(node_name, trt_dtype,
5941                                                      trt_dims, batch_size));
5942       }
5943     } else if (IsEngineOutput(node_name)) {
5944       int32 slot_number = -1;
5945       if (node_def.op() == "Identity") {
5946         if (!strings::safe_strto32(  // non-absl ok
5947                 node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
5948                 &slot_number)) {
5949           return errors::InvalidArgument("Failed to parse slot number from ",
5950                                          node_name);
5951         }
5952       } else if (tensorflow::grappler::IsRetval(node_def)) {
5953         slot_number = node_def.attr().at("index").i();
5954       } else {
5955         return errors::InvalidArgument(
5956             "Node with name ", node_name,
5957             " starting with IONamePrefixes::kOutputPHName is "
5958             "neither Identity nor Retval, instead ",
5959             node_def.op());
5960       }
5961       // Get output type that TensorFlow expects
5962       string out_type_key;
5963       if (node_def.op() == "ReadVariableOp" ||
5964           node_def.op() == "ResourceGather") {
5965         out_type_key = "dtype";
5966       } else {
5967         out_type_key = "T";
5968       }
5969       DataType tf_dtype;
5970       TF_RETURN_IF_ERROR(
5971           GetNodeAttr(AttrSlice(node_def), out_type_key, &tf_dtype));
5972       nvinfer1::DataType trt_dtype;
5973       TF_RETURN_IF_ERROR(TfTypeToTrtType(tf_dtype, &trt_dtype));
5974       if (output_tensors.size() <= slot_number) {
5975         output_tensors.resize(slot_number + 1);
5976       }
5977       output_tensors.at(slot_number) = {node_def.input(0), node_name,
5978                                         trt_dtype};
5979     } else {
5980       TF_RETURN_IF_ERROR(converter->ConvertNode(node_def));
5981     }
5982 
5983     // To support TF-TRT profiling, we ensure each ILayer has a non-empty name.
5984     // BuildCudaEngine returns an error if there is any ILayer name collision.
5985     // We want to report the error here before BuildCudaEngine in a more
5986     // meaningful way.
5987     int new_num_layers = converter->network()->getNbLayers();
5988     for (int i = num_layers; i < new_num_layers; i++) {
5989       auto layer = converter->network()->getLayer(i);
5990       if (layer->getName() == nullptr ||
5991           !layer_names.insert(layer->getName()).second) {
5992         std::string error_message = absl::StrCat(
5993             "Converting node ", node_name, ", op=", node_def.op(),
5994             layer->getName() ? " creates a layer with name collision"
5995                              : " creates a layer without a name");
5996         LOG_WARNING_WITH_PREFIX << error_message;
5997         return errors::Internal(error_message);
5998       }
5999     }
6000     num_layers = new_num_layers;
6001   }
6002   TF_RETURN_IF_ERROR(converter->RenameAndMarkOutputTensors(output_tensors));
6003   if (convert_successfully) *convert_successfully = true;
6004 
6005   // Apply user provided quantization ranges to tensors
6006   if (!use_explicit_precision) {
6007     converter->MaybeApplyQuantizationRanges();
6008   }
6009 
6010   // Build the engine.
6011   TF_RETURN_IF_ERROR(converter->BuildCudaEngine(
6012       engine, max_batch_size, max_workspace_size_bytes, allocator, calibrator,
6013       profiles));
6014 
6015   VLOG(1) << "Finished conversion";
6016   return Status::OK();
6017 }
6018 
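// Extracts the segment given by subgraph_nodes (already in topological order)
// into engine_info->segment_graph_def: each boundary connection becomes an
// _Arg or _Retval node named after IONamePrefixes::kInputPHName /
// kOutputPHName plus its port number, internal nodes are copied verbatim,
// inputs that cross the segment boundary are rewired to the corresponding
// _Arg nodes, and control inputs originating outside the segment are dropped.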
6019 Status ConvertSegmentToGraphDef(
6020     const Graph* graph, const grappler::GraphProperties& graph_properties,
6021     const std::vector<const Node*>& subgraph_nodes,  // In topological order
6022     EngineInfo* engine_info) {
6023   std::vector<EngineConnection>* connections = &engine_info->connections;
6024   GraphDef* segment_def = &engine_info->segment_graph_def;
6025   std::set<string> marker_nodes;
6026   // Update connection shapes/data types and add corresponding input/output
6027   // nodes in the segment graphdef.
6028   for (size_t i = 0; i < connections->size(); ++i) {
6029     auto& connection = connections->at(i);
6030     if (connection.is_control_edge()) continue;
6031     auto outside_node = graph->FindNodeId(connection.outside_id);
6032     if (!outside_node) {
6033       // This should never happen, unless the original graph is problematic.
6034       return errors::NotFound("Cannot find node with id ",
6035                               connection.outside_id, " in the graph.");
6036     }
6037     // Updates the shape and data types of input/output connections.
6038     DataType dtype;
6039     PartialTensorShape partial_shape;
6040     if (connection.is_input_edge) {
6041       GetOutputProperties(graph_properties,
6042                           graph->FindNodeId(connection.outside_id),
6043                           connection.outside_port, &partial_shape, &dtype);
6044       connection.outside_shape = partial_shape;
6045     } else {
6046       GetInputProperties(graph_properties,
6047                          graph->FindNodeId(connection.outside_id),
6048                          connection.outside_port, &partial_shape, &dtype);
6049       connection.inside_shape = partial_shape;
6050     }
6051     connection.connection_type = dtype;
6052 
6053     // Add dummy input/output nodes to the segment graphdef.
6054     if (connection.is_input_edge) {
6055       const string node_name =
6056           StrCat(IONamePrefixes::kInputPHName, connection.port_number);
6057       if (marker_nodes.count(node_name)) {
6058         VLOG(1) << "Reusing input " << node_name << " for the edge "
6059                 << connection.outside_node_name << ":"
6060                 << connection.outside_port << " -> "
6061                 << connection.inside_node_name << ":" << connection.inside_port;
6062         continue;
6063       }
6064       marker_nodes.insert(node_name);
6065       auto seg_node = segment_def->add_node();
6066       NodeDefBuilder builder(node_name, "_Arg");
6067       TF_RETURN_IF_ERROR(builder.Attr("shape", partial_shape)
6068                              .Attr("T", dtype)
6069                              .Attr("index", connection.port_number)
6070                              .Finalize(seg_node));
6071       VLOG(1) << "Constructing input " << node_name << " for the edge "
6072               << connection.outside_node_name << ":" << connection.outside_port
6073               << " -> " << connection.inside_node_name << ":"
6074               << connection.inside_port;
6075     } else {
6076       const string node_name =
6077           StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
6078       if (marker_nodes.count(node_name)) {
6079         VLOG(1) << "Reusing output " << node_name << " for the edge "
6080                 << connection.inside_node_name << ":" << connection.inside_port
6081                 << " -> " << connection.outside_node_name << ":"
6082                 << connection.outside_port;
6083         continue;
6084       }
6085       marker_nodes.insert(node_name);
6086       auto seg_node = segment_def->add_node();
6087       NodeDefBuilder builder(node_name, "_Retval");
6088       TF_RETURN_IF_ERROR(
6089           builder.Attr("T", dtype)
6090               .Attr("index", connection.port_number)
6091               .Input(connection.inside_node_name, connection.inside_port, dtype)
6092               .Finalize(seg_node));
6093       VLOG(1) << "Constructing output " << node_name << " for the edge "
6094               << connection.inside_node_name << ":" << connection.inside_port
6095               << " -> " << connection.outside_node_name << ":"
6096               << connection.outside_port;
6097     }
6098   }  // for each connection.
6099 
6100   std::unordered_map<int, int> old_to_new_id_map;
6101   // Copy internal nodes to new graphdef
6102   string local_scope = subgraph_nodes.front()->name();
6103   for (const Node* node : subgraph_nodes) {
6104     local_scope = GetCommonNameScope(local_scope, node->name());
6105     old_to_new_id_map[node->id()] = segment_def->node_size();
6106     auto snode = segment_def->add_node();
6107     *snode = node->def();
6108     VLOG(2) << "Copying " << snode->name() << " to subgraph";
6109   }
6110   // Update the inputs of the new input nodes to point to placeholder nodes.
6111   for (int i = 0; i < connections->size(); ++i) {
6112     auto& connection = connections->at(i);
6113     if (connection.is_control_edge() || !connection.is_input_edge) continue;
6114     auto snode =
6115         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
6116     const string arg_name =
6117         StrCat(IONamePrefixes::kInputPHName, connection.port_number);
6118     VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
6119             << " from " << snode->input(connection.inside_port) << " to "
6120             << arg_name;
6121     snode->set_input(connection.inside_port, arg_name);
6122   }
6123   std::set<string> subgraph_node_names;
6124   for (const Node* node : subgraph_nodes) {
6125     subgraph_node_names.insert(node->name());
6126   }
6127 
6128   // Remove control inputs that are not inside the segment.
6129   for (int i = 0; i < segment_def->node_size(); ++i) {
6130     auto snode = segment_def->mutable_node(i);
6131     const int input_size = snode->input_size();
6132     int input_idx = 0;
6133     int actual_input_idx = 0;
6134     while (input_idx < input_size) {
6135       TensorId input = ParseTensorName(snode->input(input_idx));
6136       if (!subgraph_node_names.count(
6137               string(input.first.data(), input.first.size())) &&
6138           !IsEngineInput(input.first)) {
6139         if (input.second == Graph::kControlSlot) {
6140           VLOG(1) << "... removing control inputs " << input.first
6141                   << " from subgraph.";
6142           ++input_idx;
6143           continue;
6144         }
6145         /// TODO(lsugy): throw error when it's not a resource input.
6146       }
6147       if (actual_input_idx != input_idx) {
6148         snode->set_input(actual_input_idx, snode->input(input_idx));
6149       }
6150       ++input_idx;
6151       ++actual_input_idx;
6152     }
6153     for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
6154       snode->mutable_input()->RemoveLast();
6155     }
6156   }
6157   return Status::OK();
6158 }
6159 
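// Decides whether an edge leaving the segment may feed an engine output:
// control edges always pass, while data edges whose source is a Const node
// are rejected so that constants do not become engine outputs.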
6160 bool OutputEdgeValidator::operator()(const Edge* out_edge) const {
6161   if (out_edge->IsControlEdge()) return true;
6162   if (out_edge->src()->type_string() == "Const") {
6163     VLOG(1) << "--> Need to remove output node " << out_edge->src()->name()
6164             << " which is a Const.";
6165     return false;
6166   }
6167   return true;
6168 }
6169 
6170 }  // namespace convert
6171 }  // namespace tensorrt
6172 }  // namespace tensorflow
6173 
6174 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
6175