xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_
16 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_
17 #if GOOGLE_CUDA && GOOGLE_TENSORRT
18 
19 #include <type_traits>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
23 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/platform/statusor.h"
26 #include "third_party/tensorrt/NvInfer.h"
27 #include "third_party/tensorrt/NvInferRuntimeCommon.h"
28 
29 namespace tensorflow {
30 namespace tensorrt {
31 
32 namespace convert {
33 
34 // Facilitates the creation of TensorRT layers inside a network. The user
35 // provides a INetworkDefinition pointer during construction. They can then add
36 // operations to the network through the provided functions. Each function
37 // returns a struct which contains the symbolic result of the operation (ITensor
38 // pointer) as well as a pointer to the last TensorRT ILayer created. Some
39 // operations may create multiple layers in order to accomplish the desired
40 // result (e.g. Sign).
41 class TRTNetworkBuilder {
42  public:
Create(nvinfer1::INetworkDefinition * network,TrtWeightStore * weight_store)43   static StatusOr<TRTNetworkBuilder> Create(
44       nvinfer1::INetworkDefinition* network, TrtWeightStore* weight_store) {
45     TRT_ENSURE(network);
46     TRT_ENSURE(weight_store);
47     return TRTNetworkBuilder(network, weight_store);
48   }
49 
50  private:
TRTNetworkBuilder(nvinfer1::INetworkDefinition * network,TrtWeightStore * weight_store)51   TRTNetworkBuilder(nvinfer1::INetworkDefinition* network,
52                     TrtWeightStore* weight_store)
53       : network_(network), weight_store_(weight_store) {}
54 
55  public:
56   // Adds an Add operation to the network.
Add(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)57   StatusOr<nvinfer1::IElementWiseLayer*> Add(nvinfer1::ITensor* lhs,
58                                              nvinfer1::ITensor* rhs) noexcept {
59     TRT_ENSURE(lhs);
60     TRT_ENSURE(rhs);
61     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
62         *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM);
63     TRT_ENSURE(layer);
64     return layer;
65   };
66 
67   // Adds an elementwise min(lhs, rhs) operation to the network. The output has
68   // the same data type as the input.
Min(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)69   StatusOr<nvinfer1::IElementWiseLayer*> Min(nvinfer1::ITensor* lhs,
70                                              nvinfer1::ITensor* rhs) noexcept {
71     TRT_ENSURE(lhs);
72     TRT_ENSURE(rhs);
73     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
74         *lhs, *rhs, nvinfer1::ElementWiseOperation::kMIN);
75     TRT_ENSURE(layer);
76     return layer;
77   };
78 
79   // Adds an elementwise max(lhs, rhs) operation to the network. The output has
80   // the same datatype as the input.
Max(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)81   StatusOr<nvinfer1::IElementWiseLayer*> Max(nvinfer1::ITensor* lhs,
82                                              nvinfer1::ITensor* rhs) noexcept {
83     TRT_ENSURE(lhs);
84     TRT_ENSURE(rhs);
85     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
86         *lhs, *rhs, nvinfer1::ElementWiseOperation::kMAX);
87     TRT_ENSURE(layer);
88     return layer;
89   };
90 
91   // Adds an absolute value operation to the network. Note that this unary
92   // operation will do an implict float conversion. For int32 tensors, use
93   // "AbsInt".
AbsFloat(nvinfer1::ITensor * input)94   StatusOr<nvinfer1::IUnaryLayer*> AbsFloat(nvinfer1::ITensor* input) noexcept {
95     TRT_ENSURE(input);
96     TRT_ENSURE(input->getType() != nvinfer1::DataType::kFLOAT &&
97                input->getType() != nvinfer1::DataType::kHALF);
98     nvinfer1::IUnaryLayer* layer =
99         network_->addUnary(*input, nvinfer1::UnaryOperation::kABS);
100     TRT_ENSURE(layer);
101     return layer;
102   }
103 
104   // Performs Abs without implict float conversion. The input should be of type
105   // kInt32. For float datatypes, use "Abs".
AbsInt(nvinfer1::ITensor * input)106   StatusOr<nvinfer1::IElementWiseLayer*> AbsInt(
107       nvinfer1::ITensor* input) noexcept {
108     TRT_ENSURE(input);
109     TRT_ENSURE(input->getType() == nvinfer1::DataType::kINT32);
110     StatusOr<nvinfer1::IElementWiseLayer*> sign = this->SignInt(input);
111     return this->Mul(input, (*sign)->getOutput(0));
112   }
113 
114   // Returns elementwise sign(x) for int32 input tensors where sign(x) is
115   // defined as 1 where x > 0, -1 where x < 0 and 0 where x == 0.
SignInt(nvinfer1::ITensor * input)116   StatusOr<nvinfer1::IElementWiseLayer*> SignInt(
117       nvinfer1::ITensor* input) noexcept {
118     TRT_ENSURE(input);
119 
120     // Create constants +1 and -1.
121     StatusOr<nvinfer1::IConstantLayer*> one =
122         this->Constant<int32>(1, input->getDimensions().nbDims);
123     TRT_ENSURE_PTR_OK(one);
124 
125     StatusOr<nvinfer1::IConstantLayer*> neg_one =
126         this->Constant<int32>(-1, input->getDimensions().nbDims);
127     TRT_ENSURE_PTR_OK(neg_one);
128 
129     // Turn all negaitve elements into -1, positive and zero elements
130     // unaffected.
131     StatusOr<nvinfer1::IElementWiseLayer*> max =
132         this->Max(input, (*neg_one)->getOutput(0));
133     TRT_ENSURE_PTR_OK(max);
134 
135     // Turn all positive elements into +1, negative and zero elements
136     // unaffected.
137     StatusOr<nvinfer1::IElementWiseLayer*> min =
138         this->Min((*max)->getOutput(0), (*one)->getOutput(0));
139     TRT_ENSURE_PTR_OK(min);
140     return min;
141   }
142 
143   // Adds a Sub operation to the network.
Sub(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)144   StatusOr<nvinfer1::IElementWiseLayer*> Sub(nvinfer1::ITensor* lhs,
145                                              nvinfer1::ITensor* rhs) noexcept {
146     TRT_ENSURE(lhs);
147     TRT_ENSURE(rhs);
148     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
149         *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUB);
150     TRT_ENSURE(layer);
151     return layer;
152   }
153 
154   // Adds an Greater operation to the network.
Greater(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)155   StatusOr<nvinfer1::IElementWiseLayer*> Greater(
156       nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
157     TRT_ENSURE(lhs);
158     TRT_ENSURE(rhs);
159     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
160         *lhs, *rhs, nvinfer1::ElementWiseOperation::kGREATER);
161     TRT_ENSURE(layer);
162     return layer;
163   }
164 
165   // Adds an Equal operation to the network.
Equal(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)166   StatusOr<nvinfer1::IElementWiseLayer*> Equal(
167       nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
168     TRT_ENSURE(lhs);
169     TRT_ENSURE(rhs);
170     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
171         *lhs, *rhs, nvinfer1::ElementWiseOperation::kEQUAL);
172     TRT_ENSURE(layer);
173     return layer;
174   }
175 
176   // Adds a FloorDiv operation to the network.
FloorDiv(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)177   StatusOr<nvinfer1::IElementWiseLayer*> FloorDiv(
178       nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
179     TRT_ENSURE(lhs);
180     TRT_ENSURE(rhs);
181     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
182         *lhs, *rhs, nvinfer1::ElementWiseOperation::kFLOOR_DIV);
183     TRT_ENSURE(layer);
184     return layer;
185   }
186 
187   // Returns the equivalent of ceil_divide(abs(x)/abs(y))) operation. The inputs
188   // "lhs" and "rhs" should be int32 tensors.
AbsCeilDivInt(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)189   StatusOr<nvinfer1::IElementWiseLayer*> AbsCeilDivInt(
190       nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
191     TRT_ENSURE(lhs);
192     TRT_ENSURE(rhs);
193     TRT_ENSURE(lhs->getType() == nvinfer1::DataType::kINT32);
194     TRT_ENSURE(rhs->getType() == nvinfer1::DataType::kINT32);
195 
196     StatusOr<nvinfer1::IElementWiseLayer*> rhs_abs = this->AbsInt(rhs);
197     TRT_ENSURE_PTR_OK(rhs_abs);
198     StatusOr<nvinfer1::IElementWiseLayer*> lhs_abs = this->AbsInt(lhs);
199     TRT_ENSURE_PTR_OK(lhs_abs);
200     StatusOr<nvinfer1::IElementWiseLayer*> add1 =
201         this->Add((*lhs_abs)->getOutput(0), (*rhs_abs)->getOutput(0));
202     TRT_ENSURE_PTR_OK(add1);
203     StatusOr<nvinfer1::IConstantLayer*> one_const =
204         this->Constant<int32>(1, rhs->getDimensions().nbDims);
205     TRT_ENSURE_PTR_OK(one_const);
206     StatusOr<nvinfer1::IElementWiseLayer*> numerator =
207         this->Sub((*add1)->getOutput(0), (*one_const)->getOutput(0));
208     TRT_ENSURE_PTR_OK(numerator);
209     return FloorDiv((*numerator)->getOutput(0), (*rhs_abs)->getOutput(0));
210   }
211 
212   // Adds an elementwise multiplication operation to the network.
Mul(nvinfer1::ITensor * lhs,nvinfer1::ITensor * rhs)213   StatusOr<nvinfer1::IElementWiseLayer*> Mul(nvinfer1::ITensor* lhs,
214                                              nvinfer1::ITensor* rhs) noexcept {
215     TRT_ENSURE(lhs);
216     TRT_ENSURE(rhs);
217     nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
218         *lhs, *rhs, nvinfer1::ElementWiseOperation::kPROD);
219     TRT_ENSURE(layer);
220     return layer;
221   }
222 
223   // Adds a sequence of elementwise multiplication operations to the network.
224   // The returned layer's output contains the cumulative elementwise product of
225   // all tensors in the input.
CumulativeProd(absl::Span<nvinfer1::ITensor * > inputs)226   StatusOr<nvinfer1::ILayer*> CumulativeProd(
227       absl::Span<nvinfer1::ITensor*> inputs) noexcept {
228     TRT_ENSURE(!absl::c_any_of(
229         inputs, [](nvinfer1::ITensor* x) { return x == nullptr; }));
230     nvinfer1::ILayer* out = nullptr;
231     if (inputs.size() == 1) {
232       out = network_->addIdentity(*inputs[0]);
233       TRT_ENSURE(out != nullptr);
234       return out;
235     }
236     nvinfer1::ITensor* last = inputs[0];
237     for (int i = 1; i < inputs.size(); i++) {
238       StatusOr<nvinfer1::IElementWiseLayer*> mul = this->Mul(last, inputs[i]);
239       TRT_ENSURE_PTR_OK(mul);
240       out = *mul;
241       last = (*mul)->getOutput(0);
242     }
243     return out;
244   }
245 
246   // Adds a Constant layer whose output is a TensorRT shape tensor. The shape
247   // tensor's size and values correspond to dim's nbDims and d[], respectively.
ConstantShape(const DimsAdapter & shape_data)248   StatusOr<nvinfer1::IConstantLayer*> ConstantShape(
249       const DimsAdapter& shape_data) noexcept {
250     TRT_ENSURE(shape_data.NumDims() > 0);
251     nvinfer1::Dims shape_dims;
252     shape_dims.nbDims = 1;
253     shape_dims.d[0] = shape_data.NumDims();
254     StatusOr<TRT_ShapedWeights> const_weights =
255         weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims);
256     TRT_ENSURE_OK(const_weights);
257     absl::c_copy(shape_data, const_weights->GetPointer<int32>());
258     StatusOr<nvinfer1::Dims> trt_dims = const_weights->Shape().AsTrtDims();
259     TRT_ENSURE_OK(trt_dims);
260     nvinfer1::IConstantLayer* const_layer =
261         network_->addConstant(*trt_dims, const_weights->GetTrtWeights());
262     TRT_ENSURE(const_layer);
263     nvinfer1::ITensor* output = const_layer->getOutput(0);
264     TRT_ENSURE(output);
265     TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32);
266     return const_layer;
267   }
268 
269   // Adds a Constant layer whose output is a TensorRT shape tensor. The shape
270   // tensor's size and values correspond to dim's nbDims and d[], respectively.
Constant(const std::vector<int> & data)271   StatusOr<nvinfer1::IConstantLayer*> Constant(
272       const std::vector<int>& data) noexcept {
273     nvinfer1::Dims shape_dims;
274     shape_dims.nbDims = 1;
275     shape_dims.d[0] = data.size();
276     StatusOr<TRT_ShapedWeights> const_weights =
277         weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims);
278     TRT_ENSURE_OK(const_weights);
279     int32* values = const_weights->GetPointer<int32>();
280     for (int i = 0; i < data.size(); i++) {
281       values[i] = static_cast<int32>(data[i]);
282     }
283     StatusOr<nvinfer1::Dims> trt_dims = const_weights->Shape().AsTrtDims();
284     TRT_ENSURE_OK(trt_dims);
285     nvinfer1::IConstantLayer* const_layer =
286         network_->addConstant(*trt_dims, const_weights->GetTrtWeights());
287     TRT_ENSURE(const_layer);
288     nvinfer1::ITensor* output = const_layer->getOutput(0);
289     TRT_ENSURE(output);
290     TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32);
291     TRT_ENSURE(const_layer);
292     return const_layer;
293   }
294 
295   // Adds a Constant layer that produces a tensor of shape "shape",
296   // type "data_type" and filled with value "scalar".
297   template <typename T>
Constant(const T value,nvinfer1::Dims shape,nvinfer1::DataType data_type)298   StatusOr<nvinfer1::IConstantLayer*> Constant(
299       const T value, nvinfer1::Dims shape,
300       nvinfer1::DataType data_type) noexcept {
301     StatusOr<TRT_ShapedWeights> const_weights =
302         weight_store_->GetTempWeights(data_type, shape);
303     TRT_ENSURE_OK(const_weights);
304     TRT_ENSURE(const_weights->SetValues(value).ok());
305     nvinfer1::IConstantLayer* const_layer =
306         network_->addConstant(shape, const_weights->GetTrtWeights());
307     TRT_ENSURE(const_layer);
308     return const_layer;
309   }
310 
311   // Adds a Constant layer that produces a tensor with a single value "scalar".
312   // The tensor has "nb_dims" dimensions and each dimension has only one
313   // element. The data type of the tensor is determined by the data type of
314   // "scalar".
315   template <typename T,
316             typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
Constant(const T scalar,const int nb_dims)317   StatusOr<nvinfer1::IConstantLayer*> Constant(const T scalar,
318                                                const int nb_dims) noexcept {
319     TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS);
320     auto data_type = nvinfer1::DataType::kINT32;
321     if (std::is_floating_point<T>::value) {
322       data_type = nvinfer1::DataType::kFLOAT;
323     }
324     nvinfer1::Dims zero_shape;
325     zero_shape.nbDims = nb_dims;
326     std::fill_n(zero_shape.d, nb_dims, 1);
327     return Constant<T>(scalar, zero_shape, data_type);
328   }
329 
330   // Adds a Constant layer from a TRT_ShapedWeights object.
WeightsToConstant(const nvinfer1::Weights & weights,const DimsAdapter & dims)331   StatusOr<nvinfer1::IConstantLayer*> WeightsToConstant(
332       const nvinfer1::Weights& weights, const DimsAdapter& dims) noexcept {
333     StatusOr<int64_t> vol = dims.Volume();
334     TRT_ENSURE_OK(vol);
335     TRT_ENSURE(*vol == weights.count);
336     StatusOr<nvinfer1::Dims> trt_dims = dims.AsTrtDims();
337     TRT_ENSURE_OK(trt_dims);
338     nvinfer1::IConstantLayer* const_layer =
339         network_->addConstant(*trt_dims, weights);
340     TRT_ENSURE(const_layer);
341     return const_layer;
342   }
343 
get_tensor4TensorOrWeights(const TRT_TensorOrWeights & input,ITensorProxyPtr * pTensor)344   Status get_tensor4TensorOrWeights(const TRT_TensorOrWeights& input,
345                                     ITensorProxyPtr* pTensor) {
346     if (input.is_weights()) {
347       StatusOr<nvinfer1::IConstantLayer*> const_layer = WeightsToConstant(
348           input.weights().GetTrtWeights(), input.GetTrtDims());
349       if (!const_layer.status().ok()) return const_layer.status();
350       *pTensor = (*const_layer)->getOutput(0);
351     } else {
352       *pTensor = input.tensor();
353     }
354     return Status::OK();
355   }
356   // Creates a nvinfer1::Weights object containing a single scalar.
357   template <typename T,
358             typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
ScalarWeights(const T scalar,const int nb_dims)359   StatusOr<nvinfer1::Weights> ScalarWeights(const T scalar,
360                                             const int nb_dims) noexcept {
361     TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS);
362     auto data_type = nvinfer1::DataType::kINT32;
363     if (std::is_floating_point<T>::value) {
364       data_type = nvinfer1::DataType::kFLOAT;
365     }
366     nvinfer1::Dims weights_shape;
367     weights_shape.nbDims = nb_dims;
368     std::fill_n(weights_shape.d, nb_dims, 1);
369     StatusOr<TRT_ShapedWeights> const_weights =
370         weight_store_->GetTempWeights(data_type, weights_shape);
371     TRT_ENSURE_OK(const_weights);
372     const_weights->GetPointer<T>()[0] = scalar;
373     return const_weights->GetTrtWeights();
374   }
375 
376   // Adds a TensorRT Slice operation to the network.
Slice(nvinfer1::ITensor * input,const nvinfer1::Dims & begin,const nvinfer1::Dims & size,const nvinfer1::Dims & stride)377   StatusOr<nvinfer1::ISliceLayer*> Slice(
378       nvinfer1::ITensor* input, const nvinfer1::Dims& begin,
379       const nvinfer1::Dims& size, const nvinfer1::Dims& stride) noexcept {
380     nvinfer1::ISliceLayer* layer =
381         network_->addSlice(*input, begin, size, stride);
382     TRT_ENSURE(layer);
383     return layer;
384   }
385 
386   // Adds a TensorRT Concatenate operation to the network.
Concat(absl::Span<nvinfer1::ITensor * const> inputs,const int axis)387   StatusOr<nvinfer1::IConcatenationLayer*> Concat(
388       absl::Span<nvinfer1::ITensor* const> inputs, const int axis) {
389     for (nvinfer1::ITensor* input : inputs) {
390       TRT_ENSURE(input);
391     }
392     nvinfer1::IConcatenationLayer* layer = network_->addConcatenation(
393         inputs.data(), static_cast<int32_t>(inputs.size()));
394     TRT_ENSURE(layer);
395     layer->setAxis(axis);
396     return layer;
397   }
398 
399   // Adds a TensorRT Concatenate operation to the network.
Concat(const std::vector<nvinfer1::ITensor * > & inputs,const int axis)400   StatusOr<nvinfer1::IConcatenationLayer*> Concat(
401       const std::vector<nvinfer1::ITensor*>& inputs, const int axis) {
402     return this->Concat(absl::MakeSpan(inputs), axis);
403   }
404 
405   // Adds a TensorRT Shape operation, which determines the runtime shape of the
406   // input tensor, to the network.
Shape(nvinfer1::ITensor * input)407   StatusOr<nvinfer1::IShapeLayer*> Shape(nvinfer1::ITensor* input) {
408     TRT_ENSURE(input);
409     nvinfer1::IShapeLayer* layer = network_->addShape(*input);
410     TRT_ENSURE(layer);
411     return layer;
412   }
413 
414   // Creates a Gather operation on the shape of the input tensor. The output of
415   // the gather operation is a 1D shape tensor where output[i] = (!sub_one ?
416   // input_shape[i] : input_shape[i] -1) if i is in "indices", otherwise zero.
417   StatusOr<nvinfer1::IGatherLayer*> GetPartialShapeOf(
418       nvinfer1::ITensor* input, absl::InlinedVector<int64, 4> indices,
419       bool sub_one = false) {
420     TRT_ENSURE(input);
421     TRT_ENSURE(indices.size() <= nvinfer1::Dims::MAX_DIMS);
422 
423     // Get the runtime shape of input;
424     StatusOr<nvinfer1::IShapeLayer*> shape_layer = this->Shape(input);
425     TRT_ENSURE_PTR_OK(shape_layer);
426     nvinfer1::ITensor* runtime_shape = (*shape_layer)->getOutput(0);
427 
428     if (sub_one) {
429       StatusOr<nvinfer1::IConstantLayer*> ones = this->Constant<int32>(1, 1);
430       TRT_ENSURE_PTR_OK(ones);
431       StatusOr<nvinfer1::IElementWiseLayer*> sub =
432           this->Sub(runtime_shape, (*ones)->getOutput(0));
433       TRT_ENSURE_PTR_OK(sub);
434       runtime_shape = (*sub)->getOutput(0);
435     }
436 
437     // Create a constant tensor containing the gather indices.
438     // For any dim not in "indices", we mark it size to gather a zero.
439     const int input_nb_dims = input->getDimensions().nbDims;
440     std::vector<int> indices_all(input_nb_dims, input_nb_dims);
441     for (auto idx : indices) {
442       TRT_ENSURE(idx < input_nb_dims);
443       indices_all[idx] = idx;
444     }
445 
446     StatusOr<nvinfer1::IConstantLayer*> indices_result =
447         this->Constant(indices_all);
448     TRT_ENSURE_PTR_OK(indices_result);
449     nvinfer1::ITensor* gather_indices = (*indices_result)->getOutput(0);
450     TRT_ENSURE(gather_indices->getDimensions().nbDims == 1);
451     TRT_ENSURE(gather_indices->getType() == nvinfer1::DataType::kINT32);
452 
453     // Append a zero to the shape tensor.
454     StatusOr<nvinfer1::IConstantLayer*> zero_result =
455         this->Constant(std::vector<int>{0});
456     TRT_ENSURE_PTR_OK(zero_result);
457     std::array<nvinfer1::ITensor*, 2> cat_inputs = {
458         runtime_shape, (*zero_result)->getOutput(0)};
459     nvinfer1::IConcatenationLayer* cat_layer =
460         network_->addConcatenation(cat_inputs.data(), cat_inputs.size());
461     TRT_ENSURE(cat_layer);
462     nvinfer1::ITensor* gather_input = cat_layer->getOutput(0);
463     TRT_ENSURE(gather_input);
464 
465     // Finally, gather the indices from the input.
466     nvinfer1::IGatherLayer* gather =
467         network_->addGather(*gather_input, *gather_indices, 0);
468     TRT_ENSURE(gather);
469     return gather;
470   }
471 
472   // Adds a scale layer that uniformly scales the input tensor by the specified
473   // amount.
AddUniformScale(nvinfer1::ITensor * input,float scale,const std::string & name)474   StatusOr<nvinfer1::IScaleLayer*> AddUniformScale(nvinfer1::ITensor* input,
475                                                    float scale,
476                                                    const std::string& name) {
477     TRT_ENSURE(input);
478     TRT_ENSURE(!name.empty());
479     StatusOr<nvinfer1::Weights> weight = this->ScalarWeights<float>(scale, 1);
480     TRT_ENSURE_OK(weight);
481     const nvinfer1::Weights empty_weights =
482         nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
483     nvinfer1::IScaleLayer* scale_layer =
484         network_->addScale(*input, nvinfer1::ScaleMode::kUNIFORM, empty_weights,
485                            (*weight), empty_weights);
486     TRT_ENSURE(scale_layer != nullptr);
487     scale_layer->setName(name.c_str());
488     TRT_ENSURE((*scale_layer).getPower().count == 0);
489     TRT_ENSURE((*scale_layer).getShift().count == 0);
490     TRT_ENSURE((*scale_layer).getScale().count == 1);
491     return scale_layer;
492   }
493 
494   StatusOr<nvinfer1::ILayer*> AddFill(const TRT_TensorOrWeights& value_input,
495                                       const TRT_TensorOrWeights& dims_input,
496                                       bool is_value_static, bool is_dims_static,
497                                       int nbDims,
498                                       const nvinfer1::Dims& trt_dims,
499                                       ITensorProxyPtr scalar_tensor = nullptr,
500                                       ITensorProxyPtr beta_tensor = nullptr,
501                                       const float delta = 0) {
502     // TensorRT IFillLayer requires a rank 0 scalar.
503     nvinfer1::Dims scalar_dims;
504     scalar_dims.nbDims = 0;
505     if (is_value_static) {
506       StatusOr<nvinfer1::IConstantLayer*> const_layer =
507           WeightsToConstant(value_input.weights().GetTrtWeights(), scalar_dims);
508       if (!const_layer.status().ok()) return const_layer.status();
509       scalar_tensor = (*const_layer)->getOutput(0);
510     } else {
511       if (scalar_tensor == nullptr) {
512         StatusOr<nvinfer1::IShuffleLayer*> shuffler_layer =
513             Reshape(value_input.tensor()->trt_tensor(), scalar_dims);
514         if (!shuffler_layer.status().ok()) return shuffler_layer.status();
515         scalar_tensor = (*shuffler_layer)->getOutput(0);
516       }
517     }
518 
519     if (beta_tensor == nullptr) {
520       nvinfer1::Dims beta_shape{1, {nbDims}};
521       StatusOr<nvinfer1::IConstantLayer*> const_layer =
522           Constant(delta, beta_shape, value_input.TrtDType());
523       TF_RETURN_IF_ERROR(const_layer.status());
524       beta_tensor = (*const_layer)->getOutput(0);
525     }
526 
527     nvinfer1::IFillLayer* layer =
528         network_->addFill(trt_dims, nvinfer1::FillOperation::kLINSPACE);
529     TRT_ENSURE(layer);
530     if (!is_dims_static) {
531       layer->setInput(0, *dims_input.tensor()->trt_tensor());
532     }
533     layer->setInput(1, *scalar_tensor->trt_tensor());
534     layer->setInput(2, *beta_tensor->trt_tensor());
535     return layer;
536   }
537 
538   // Adds a quantization layer that uniformly scales the input tensor
539   // by the given multiplicative "scaling_factor", then rounds
540   // (round-to-nearest-ties-to-even) to the nearest integer and clamps in the
541   // range of [-128, 127].
Quantize(nvinfer1::ITensor * input,const float scaling_factor,const std::string & name)542   StatusOr<nvinfer1::ILayer*> Quantize(nvinfer1::ITensor* input,
543                                        const float scaling_factor,
544                                        const std::string& name) {
545     TRT_ENSURE(input);
546     TRT_ENSURE(!name.empty());
547     // Preprocessor usage here is unavoidable because TRT8 API is new.
548 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
549     // The TensorRT IQuantizeLayer divides by the scale factor rather than
550     // multiplies. To be consistent, in this function we expect a multiplicative
551     // scale factor, so we take the reciprical.
552     StatusOr<nvinfer1::IConstantLayer*> scaling_const =
553         this->Constant<float>(1.0f / scaling_factor, 1);
554     TRT_ENSURE_PTR_OK(scaling_const);
555     (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}});
556     nvinfer1::IQuantizeLayer* quant_layer =
557         network_->addQuantize(*input, *(*scaling_const)->getOutput(0));
558     TRT_ENSURE(quant_layer);
559     quant_layer->setAxis(1);
560     return quant_layer;
561 #else
562     StatusOr<nvinfer1::IScaleLayer*> result =
563         this->AddUniformScale(input, scaling_factor, name);
564     TRT_ENSURE_PTR_OK(result);
565     (*result)->setOutputType(0, nvinfer1::DataType::kINT8);
566     (*result)->setPrecision(nvinfer1::DataType::kFLOAT);
567     return result;
568 #endif
569   }
570 
571   // Adds a dequantize layer that casts the input tensor to TensorRT float type
572   // and scales it uniformly by the given multiplicative "scaling_factor".
Dequantize(nvinfer1::ITensor * input,const float scaling_factor,const std::string & name)573   StatusOr<nvinfer1::ILayer*> Dequantize(nvinfer1::ITensor* input,
574                                          const float scaling_factor,
575                                          const std::string& name) {
576     TRT_ENSURE(input);
577     TRT_ENSURE(!name.empty());
578 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
579     StatusOr<nvinfer1::IConstantLayer*> scaling_const =
580         this->Constant<float>(scaling_factor, 1);
581     TRT_ENSURE_PTR_OK(scaling_const);
582     (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}});
583     nvinfer1::IDequantizeLayer* dequant_layer =
584         network_->addDequantize(*input, *(*scaling_const)->getOutput(0));
585     dequant_layer->setAxis(1);
586     TRT_ENSURE(dequant_layer);
587     return dequant_layer;
588 #else
589     StatusOr<nvinfer1::IScaleLayer*> result =
590         this->AddUniformScale(input, scaling_factor, name);
591     TRT_ENSURE_PTR_OK(result);
592     (*result)->setOutputType(0, nvinfer1::DataType::kFLOAT);
593     (*result)->setPrecision(nvinfer1::DataType::kINT8);
594     return result;
595 #endif
596   }
597 
598   // Adds TensorRT Q/DQ operations. This is for explicit precision mode.
UniformQuantizeDequantizeExplicit(nvinfer1::ITensor * input,float quantize_scale,float dequantize_scale,const std::string & name)599   StatusOr<nvinfer1::ILayer*> UniformQuantizeDequantizeExplicit(
600       nvinfer1::ITensor* input, float quantize_scale, float dequantize_scale,
601       const std::string& name) {
602     TRT_ENSURE(input);
603     if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
604       TRT_ENSURE(network_->hasExplicitPrecision());
605     }
606     TRT_ENSURE(IS_TRT_VERSION_GE(7, 1, 0, 0));
607 
608     static int count = 0;
609     TRT_ENSURE(input->getType() == nvinfer1::DataType::kFLOAT);
610     std::string quant_name = absl::StrCat(input->getName(), "_quant_", count);
611 
612     StatusOr<nvinfer1::ILayer*> quant =
613         this->Quantize(input, quantize_scale, quant_name);
614     TRT_ENSURE_PTR_OK(quant);
615 
616     std::string dequant_name =
617         absl::StrCat(input->getName(), "_dequant_", count);
618     StatusOr<nvinfer1::ILayer*> dequant = this->Dequantize(
619         (*quant)->getOutput(0), dequantize_scale, dequant_name);
620     TRT_ENSURE_PTR_OK(dequant);
621 
622     count++;
623     return dequant;
624   }
625 
Reshape(nvinfer1::ITensor * input,const nvinfer1::Dims & new_shape)626   StatusOr<nvinfer1::IShuffleLayer*> Reshape(nvinfer1::ITensor* input,
627                                              const nvinfer1::Dims& new_shape) {
628     TRT_ENSURE(input);
629     nvinfer1::IShuffleLayer* layer = network_->addShuffle(*input);
630     TRT_ENSURE(layer);
631     layer->setReshapeDimensions(new_shape);
632     return layer;
633   }
634 
FindProducerOf(const nvinfer1::ITensor * tensor)635   StatusOr<nvinfer1::ILayer*> FindProducerOf(const nvinfer1::ITensor* tensor) {
636     const char* name = tensor->getName();
637     const int num_layers = network_->getNbLayers();
638     for (int i = 0; i < num_layers; i++) {
639       nvinfer1::ILayer* layer = network_->getLayer(i);
640       const int num_outputs = layer->getNbOutputs();
641       for (int j = 0; j < num_outputs; j++) {
642         nvinfer1::ITensor* t = layer->getOutput(j);
643         if (std::string(t->getName()) == name) {
644           return layer;
645         }
646       }
647     }
648     return errors::NotFound("could not find producing layer of ", name);
649   }
650 
651   StatusOr<nvinfer1::ILayer*> UniqueParentOf(const nvinfer1::ILayer* layer,
652                                              int input_idx = 0) {
653     return FindProducerOf(layer->getInput(input_idx));
654   }
655 
Network()656   nvinfer1::INetworkDefinition* Network() { return network_; }
657 
658  private:
659   nvinfer1::INetworkDefinition* const network_;
660   TrtWeightStore* const weight_store_;
661 };
662 
663 class ShuffleBuilder {
664  private:
ShuffleBuilder(TRTNetworkBuilder * builder,nvinfer1::ITensor * input)665   explicit ShuffleBuilder(TRTNetworkBuilder* builder, nvinfer1::ITensor* input)
666       : builder_(builder) {
667     layer_ = builder->Network()->addShuffle(*input);
668   }
669 
670  public:
Create(TRTNetworkBuilder * builder,nvinfer1::ITensor * input)671   static StatusOr<ShuffleBuilder> Create(TRTNetworkBuilder* builder,
672                                          nvinfer1::ITensor* input) {
673     TRT_ENSURE(builder != nullptr);
674     TRT_ENSURE(input != nullptr);
675     return ShuffleBuilder(builder, input);
676   }
677 
SetReshape(const nvinfer1::Dims & dims)678   ShuffleBuilder& SetReshape(const nvinfer1::Dims& dims) {
679     layer_->setReshapeDimensions(dims);
680     return *this;
681   }
682 
SetReshape(nvinfer1::ITensor * shape)683   ShuffleBuilder& SetReshape(nvinfer1::ITensor* shape) {
684     layer_->setInput(1, *shape);
685     return *this;
686   }
687 
SetFirstTranspose(const nvinfer1::Permutation & perm)688   ShuffleBuilder& SetFirstTranspose(const nvinfer1::Permutation& perm) {
689     layer_->setFirstTranspose(perm);
690     return *this;
691   }
692 
SetSecondTranspose(const nvinfer1::Permutation & perm)693   ShuffleBuilder& SetSecondTranspose(const nvinfer1::Permutation& perm) {
694     layer_->setSecondTranspose(perm);
695     return *this;
696   }
697 
Output()698   StatusOr<nvinfer1::ITensor*> Output() {
699     TRT_ENSURE(layer_ != nullptr);
700     TRT_ENSURE(layer_->getOutput(0) != nullptr);
701     return layer_->getOutput(0);
702   }
703 
704  private:
705   TRTNetworkBuilder* builder_;
706   nvinfer1::IShuffleLayer* layer_;
707 };
708 
709 }  // namespace convert
710 }  // namespace tensorrt
711 }  // namespace tensorflow
712 
713 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
714 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_
715