/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_
#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include <type_traits>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/statusor.h"
#include "third_party/tensorrt/NvInfer.h"
#include "third_party/tensorrt/NvInferRuntimeCommon.h"

namespace tensorflow {
namespace tensorrt {

namespace convert {

// Facilitates the creation of TensorRT layers inside a network. The user
// provides an INetworkDefinition pointer during construction. They can then
// add operations to the network through the provided functions. Each function
// returns a StatusOr containing a pointer to the last TensorRT ILayer created;
// the symbolic result of the operation is that layer's output (an ITensor
// pointer). Some operations may create multiple layers in order to accomplish
// the desired result (e.g. SignInt).
class TRTNetworkBuilder {
 public:
  static StatusOr<TRTNetworkBuilder> Create(
      nvinfer1::INetworkDefinition* network, TrtWeightStore* weight_store) {
    TRT_ENSURE(network);
    TRT_ENSURE(weight_store);
    return TRTNetworkBuilder(network, weight_store);
  }

 private:
  TRTNetworkBuilder(nvinfer1::INetworkDefinition* network,
                    TrtWeightStore* weight_store)
      : network_(network), weight_store_(weight_store) {}

 public:
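
  // Example usage (a minimal sketch, not part of the original header;
  // `network`, `weight_store`, `lhs`, and `rhs` are assumed to be valid
  // pointers provided by the surrounding converter):
  //
  //   StatusOr<TRTNetworkBuilder> builder =
  //       TRTNetworkBuilder::Create(network, weight_store);
  //   TRT_ENSURE_OK(builder);
  //   StatusOr<nvinfer1::IElementWiseLayer*> sum = builder->Add(lhs, rhs);
  //   TRT_ENSURE_PTR_OK(sum);
  //   nvinfer1::ITensor* result = (*sum)->getOutput(0);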

  // Adds an Add operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> Add(nvinfer1::ITensor* lhs,
                                             nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds an elementwise min(lhs, rhs) operation to the network. The output
  // has the same data type as the inputs.
  StatusOr<nvinfer1::IElementWiseLayer*> Min(nvinfer1::ITensor* lhs,
                                             nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kMIN);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds an elementwise max(lhs, rhs) operation to the network. The output
  // has the same data type as the inputs.
  StatusOr<nvinfer1::IElementWiseLayer*> Max(nvinfer1::ITensor* lhs,
                                             nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kMAX);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds an absolute value operation to the network. Note that this unary
  // operation will do an implicit float conversion. For int32 tensors, use
  // "AbsInt".
  StatusOr<nvinfer1::IUnaryLayer*> AbsFloat(
      nvinfer1::ITensor* input) noexcept {
    TRT_ENSURE(input);
    // The input must already be a float type.
    TRT_ENSURE(input->getType() == nvinfer1::DataType::kFLOAT ||
               input->getType() == nvinfer1::DataType::kHALF);
    nvinfer1::IUnaryLayer* layer =
        network_->addUnary(*input, nvinfer1::UnaryOperation::kABS);
    TRT_ENSURE(layer);
    return layer;
  }

  // Performs Abs without implicit float conversion. The input should be of
  // type kINT32. For float data types, use "AbsFloat".
  StatusOr<nvinfer1::IElementWiseLayer*> AbsInt(
      nvinfer1::ITensor* input) noexcept {
    TRT_ENSURE(input);
    TRT_ENSURE(input->getType() == nvinfer1::DataType::kINT32);
    StatusOr<nvinfer1::IElementWiseLayer*> sign = this->SignInt(input);
    TRT_ENSURE_PTR_OK(sign);
    return this->Mul(input, (*sign)->getOutput(0));
  }

  // Returns elementwise sign(x) for int32 input tensors, where sign(x) is
  // defined as 1 where x > 0, -1 where x < 0, and 0 where x == 0.
  StatusOr<nvinfer1::IElementWiseLayer*> SignInt(
      nvinfer1::ITensor* input) noexcept {
    TRT_ENSURE(input);

    // Create constants +1 and -1.
    StatusOr<nvinfer1::IConstantLayer*> one =
        this->Constant<int32>(1, input->getDimensions().nbDims);
    TRT_ENSURE_PTR_OK(one);

    StatusOr<nvinfer1::IConstantLayer*> neg_one =
        this->Constant<int32>(-1, input->getDimensions().nbDims);
    TRT_ENSURE_PTR_OK(neg_one);

    // Turn all negative elements into -1; positive and zero elements are
    // unaffected.
    StatusOr<nvinfer1::IElementWiseLayer*> max =
        this->Max(input, (*neg_one)->getOutput(0));
    TRT_ENSURE_PTR_OK(max);

    // Turn all positive elements into +1; negative and zero elements are
    // unaffected.
    StatusOr<nvinfer1::IElementWiseLayer*> min =
        this->Min((*max)->getOutput(0), (*one)->getOutput(0));
    TRT_ENSURE_PTR_OK(min);
    return min;
  }
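
  // SignInt relies on the integer clamp identity
  // sign(x) == min(max(x, -1), 1). A worked example (illustrative only):
  // for x = {-7, 0, 3}, max(x, -1) = {-1, 0, 3}, and min of that with 1 is
  // {-1, 0, 1}, which equals sign(x).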

  // Adds a Sub operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> Sub(nvinfer1::ITensor* lhs,
                                             nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUB);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds a Greater operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> Greater(
      nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kGREATER);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds an Equal operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> Equal(
      nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kEQUAL);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds a FloorDiv operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> FloorDiv(
      nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kFLOOR_DIV);
    TRT_ENSURE(layer);
    return layer;
  }

  // Returns the equivalent of a ceil_divide(abs(lhs), abs(rhs)) operation.
  // The inputs "lhs" and "rhs" should be int32 tensors.
  StatusOr<nvinfer1::IElementWiseLayer*> AbsCeilDivInt(
      nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    TRT_ENSURE(lhs->getType() == nvinfer1::DataType::kINT32);
    TRT_ENSURE(rhs->getType() == nvinfer1::DataType::kINT32);

    StatusOr<nvinfer1::IElementWiseLayer*> rhs_abs = this->AbsInt(rhs);
    TRT_ENSURE_PTR_OK(rhs_abs);
    StatusOr<nvinfer1::IElementWiseLayer*> lhs_abs = this->AbsInt(lhs);
    TRT_ENSURE_PTR_OK(lhs_abs);

    // numerator = abs(lhs) + abs(rhs) - 1.
    StatusOr<nvinfer1::IElementWiseLayer*> add1 =
        this->Add((*lhs_abs)->getOutput(0), (*rhs_abs)->getOutput(0));
    TRT_ENSURE_PTR_OK(add1);
    StatusOr<nvinfer1::IConstantLayer*> one_const =
        this->Constant<int32>(1, rhs->getDimensions().nbDims);
    TRT_ENSURE_PTR_OK(one_const);
    StatusOr<nvinfer1::IElementWiseLayer*> numerator =
        this->Sub((*add1)->getOutput(0), (*one_const)->getOutput(0));
    TRT_ENSURE_PTR_OK(numerator);
    return FloorDiv((*numerator)->getOutput(0), (*rhs_abs)->getOutput(0));
  }
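
  // AbsCeilDivInt uses the standard identity
  // ceil(a / b) == floor((a + b - 1) / b) for positive integers a and b.
  // For example (illustrative only): a = 7, b = 3 gives
  // floor((7 + 3 - 1) / 3) = floor(9 / 3) = 3 = ceil(7 / 3).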

  // Adds an elementwise multiplication operation to the network.
  StatusOr<nvinfer1::IElementWiseLayer*> Mul(nvinfer1::ITensor* lhs,
                                             nvinfer1::ITensor* rhs) noexcept {
    TRT_ENSURE(lhs);
    TRT_ENSURE(rhs);
    nvinfer1::IElementWiseLayer* layer = network_->addElementWise(
        *lhs, *rhs, nvinfer1::ElementWiseOperation::kPROD);
    TRT_ENSURE(layer);
    return layer;
  }

  // Adds a sequence of elementwise multiplication operations to the network.
  // The returned layer's output contains the cumulative elementwise product
  // of all tensors in the input.
  StatusOr<nvinfer1::ILayer*> CumulativeProd(
      absl::Span<nvinfer1::ITensor*> inputs) noexcept {
    TRT_ENSURE(!inputs.empty());
    TRT_ENSURE(!absl::c_any_of(
        inputs, [](nvinfer1::ITensor* x) { return x == nullptr; }));
    nvinfer1::ILayer* out = nullptr;
    if (inputs.size() == 1) {
      out = network_->addIdentity(*inputs[0]);
      TRT_ENSURE(out != nullptr);
      return out;
    }
    nvinfer1::ITensor* last = inputs[0];
    for (int i = 1; i < inputs.size(); i++) {
      StatusOr<nvinfer1::IElementWiseLayer*> mul = this->Mul(last, inputs[i]);
      TRT_ENSURE_PTR_OK(mul);
      out = *mul;
      last = (*mul)->getOutput(0);
    }
    return out;
  }

  // Adds a Constant layer whose output is a TensorRT shape tensor. The shape
  // tensor's size and values correspond to shape_data's nbDims and d[],
  // respectively.
  StatusOr<nvinfer1::IConstantLayer*> ConstantShape(
      const DimsAdapter& shape_data) noexcept {
    TRT_ENSURE(shape_data.NumDims() > 0);
    nvinfer1::Dims shape_dims;
    shape_dims.nbDims = 1;
    shape_dims.d[0] = shape_data.NumDims();
    StatusOr<TRT_ShapedWeights> const_weights =
        weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims);
    TRT_ENSURE_OK(const_weights);
    absl::c_copy(shape_data, const_weights->GetPointer<int32>());
    StatusOr<nvinfer1::Dims> trt_dims = const_weights->Shape().AsTrtDims();
    TRT_ENSURE_OK(trt_dims);
    nvinfer1::IConstantLayer* const_layer =
        network_->addConstant(*trt_dims, const_weights->GetTrtWeights());
    TRT_ENSURE(const_layer);
    nvinfer1::ITensor* output = const_layer->getOutput(0);
    TRT_ENSURE(output);
    TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32);
    return const_layer;
  }

  // Adds a Constant layer whose output is a 1D int32 tensor containing the
  // values in "data".
  StatusOr<nvinfer1::IConstantLayer*> Constant(
      const std::vector<int>& data) noexcept {
    nvinfer1::Dims shape_dims;
    shape_dims.nbDims = 1;
    shape_dims.d[0] = data.size();
    StatusOr<TRT_ShapedWeights> const_weights =
        weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims);
    TRT_ENSURE_OK(const_weights);
    int32* values = const_weights->GetPointer<int32>();
    for (int i = 0; i < data.size(); i++) {
      values[i] = static_cast<int32>(data[i]);
    }
    StatusOr<nvinfer1::Dims> trt_dims = const_weights->Shape().AsTrtDims();
    TRT_ENSURE_OK(trt_dims);
    nvinfer1::IConstantLayer* const_layer =
        network_->addConstant(*trt_dims, const_weights->GetTrtWeights());
    TRT_ENSURE(const_layer);
    nvinfer1::ITensor* output = const_layer->getOutput(0);
    TRT_ENSURE(output);
    TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32);
    return const_layer;
  }

  // Adds a Constant layer that produces a tensor of shape "shape" and type
  // "data_type", filled with value "value".
  template <typename T>
  StatusOr<nvinfer1::IConstantLayer*> Constant(
      const T value, nvinfer1::Dims shape,
      nvinfer1::DataType data_type) noexcept {
    StatusOr<TRT_ShapedWeights> const_weights =
        weight_store_->GetTempWeights(data_type, shape);
    TRT_ENSURE_OK(const_weights);
    TRT_ENSURE(const_weights->SetValues(value).ok());
    nvinfer1::IConstantLayer* const_layer =
        network_->addConstant(shape, const_weights->GetTrtWeights());
    TRT_ENSURE(const_layer);
    return const_layer;
  }
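
  // Example usage of the Constant overloads (a sketch, not part of the
  // original header; `builder` is an already-created TRTNetworkBuilder and
  // `dims` is an nvinfer1::Dims, both assumed):
  //
  //   // 1D int32 tensor with values {1, 2, 3}.
  //   StatusOr<nvinfer1::IConstantLayer*> vec =
  //       builder.Constant(std::vector<int>{1, 2, 3});
  //   // Tensor of shape `dims`, filled with the value 0.5f.
  //   StatusOr<nvinfer1::IConstantLayer*> half =
  //       builder.Constant<float>(0.5f, dims, nvinfer1::DataType::kFLOAT);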

  // Adds a Constant layer that produces a tensor with a single value
  // "scalar". The tensor has "nb_dims" dimensions, each of size one. The
  // data type of the tensor is determined by the data type of "scalar".
  template <typename T,
            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
  StatusOr<nvinfer1::IConstantLayer*> Constant(const T scalar,
                                               const int nb_dims) noexcept {
    TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS);
    auto data_type = nvinfer1::DataType::kINT32;
    if (std::is_floating_point<T>::value) {
      data_type = nvinfer1::DataType::kFLOAT;
    }
    nvinfer1::Dims zero_shape;
    zero_shape.nbDims = nb_dims;
    std::fill_n(zero_shape.d, nb_dims, 1);
    return Constant<T>(scalar, zero_shape, data_type);
  }

  // Adds a Constant layer from an nvinfer1::Weights object.
  StatusOr<nvinfer1::IConstantLayer*> WeightsToConstant(
      const nvinfer1::Weights& weights, const DimsAdapter& dims) noexcept {
    StatusOr<int64_t> vol = dims.Volume();
    TRT_ENSURE_OK(vol);
    TRT_ENSURE(*vol == weights.count);
    StatusOr<nvinfer1::Dims> trt_dims = dims.AsTrtDims();
    TRT_ENSURE_OK(trt_dims);
    nvinfer1::IConstantLayer* const_layer =
        network_->addConstant(*trt_dims, weights);
    TRT_ENSURE(const_layer);
    return const_layer;
  }

  // Returns "input" as an ITensor: the tensor itself if "input" holds a
  // tensor, otherwise the output of a Constant layer created from the
  // weights.
  Status get_tensor4TensorOrWeights(const TRT_TensorOrWeights& input,
                                    ITensorProxyPtr* pTensor) {
    if (input.is_weights()) {
      StatusOr<nvinfer1::IConstantLayer*> const_layer = WeightsToConstant(
          input.weights().GetTrtWeights(), input.GetTrtDims());
      if (!const_layer.status().ok()) return const_layer.status();
      *pTensor = (*const_layer)->getOutput(0);
    } else {
      *pTensor = input.tensor();
    }
    return Status::OK();
  }

  // Creates an nvinfer1::Weights object containing a single scalar.
  template <typename T,
            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
  StatusOr<nvinfer1::Weights> ScalarWeights(const T scalar,
                                            const int nb_dims) noexcept {
    TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS);
    auto data_type = nvinfer1::DataType::kINT32;
    if (std::is_floating_point<T>::value) {
      data_type = nvinfer1::DataType::kFLOAT;
    }
    nvinfer1::Dims weights_shape;
    weights_shape.nbDims = nb_dims;
    std::fill_n(weights_shape.d, nb_dims, 1);
    StatusOr<TRT_ShapedWeights> const_weights =
        weight_store_->GetTempWeights(data_type, weights_shape);
    TRT_ENSURE_OK(const_weights);
    const_weights->GetPointer<T>()[0] = scalar;
    return const_weights->GetTrtWeights();
  }

  // Adds a TensorRT Slice operation to the network.
  StatusOr<nvinfer1::ISliceLayer*> Slice(
      nvinfer1::ITensor* input, const nvinfer1::Dims& begin,
      const nvinfer1::Dims& size, const nvinfer1::Dims& stride) noexcept {
    TRT_ENSURE(input);
    nvinfer1::ISliceLayer* layer =
        network_->addSlice(*input, begin, size, stride);
    TRT_ENSURE(layer);
    return layer;
  }
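
  // Slice example (illustrative only): for a [4, 4] input,
  // begin = {2, {1, 0}}, size = {2, {2, 4}}, stride = {2, {1, 1}} selects
  // rows 1 and 2 and all columns, producing a [2, 4] output.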

  // Adds a TensorRT Concatenation operation to the network.
  StatusOr<nvinfer1::IConcatenationLayer*> Concat(
      absl::Span<nvinfer1::ITensor* const> inputs, const int axis) {
    for (nvinfer1::ITensor* input : inputs) {
      TRT_ENSURE(input);
    }
    nvinfer1::IConcatenationLayer* layer = network_->addConcatenation(
        inputs.data(), static_cast<int32_t>(inputs.size()));
    TRT_ENSURE(layer);
    layer->setAxis(axis);
    return layer;
  }

  // Adds a TensorRT Concatenation operation to the network.
  StatusOr<nvinfer1::IConcatenationLayer*> Concat(
      const std::vector<nvinfer1::ITensor*>& inputs, const int axis) {
    return this->Concat(absl::MakeSpan(inputs), axis);
  }

  // Adds a TensorRT Shape operation, which determines the runtime shape of
  // the input tensor, to the network.
  StatusOr<nvinfer1::IShapeLayer*> Shape(nvinfer1::ITensor* input) {
    TRT_ENSURE(input);
    nvinfer1::IShapeLayer* layer = network_->addShape(*input);
    TRT_ENSURE(layer);
    return layer;
  }

  // Creates a Gather operation on the shape of the input tensor. The output
  // of the gather operation is a 1D shape tensor where output[i] =
  // (!sub_one ? input_shape[i] : input_shape[i] - 1) if i is in "indices",
  // otherwise zero.
  StatusOr<nvinfer1::IGatherLayer*> GetPartialShapeOf(
      nvinfer1::ITensor* input, absl::InlinedVector<int64, 4> indices,
      bool sub_one = false) {
    TRT_ENSURE(input);
    TRT_ENSURE(indices.size() <= nvinfer1::Dims::MAX_DIMS);

    // Get the runtime shape of the input.
    StatusOr<nvinfer1::IShapeLayer*> shape_layer = this->Shape(input);
    TRT_ENSURE_PTR_OK(shape_layer);
    nvinfer1::ITensor* runtime_shape = (*shape_layer)->getOutput(0);

    if (sub_one) {
      StatusOr<nvinfer1::IConstantLayer*> ones = this->Constant<int32>(1, 1);
      TRT_ENSURE_PTR_OK(ones);
      StatusOr<nvinfer1::IElementWiseLayer*> sub =
          this->Sub(runtime_shape, (*ones)->getOutput(0));
      TRT_ENSURE_PTR_OK(sub);
      runtime_shape = (*sub)->getOutput(0);
    }

    // Create a constant tensor containing the gather indices. For any dim
    // not in "indices", set the gather index to input_nb_dims, which points
    // at the zero appended to the shape tensor below.
    const int input_nb_dims = input->getDimensions().nbDims;
    std::vector<int> indices_all(input_nb_dims, input_nb_dims);
    for (auto idx : indices) {
      TRT_ENSURE(idx < input_nb_dims);
      indices_all[idx] = idx;
    }

    StatusOr<nvinfer1::IConstantLayer*> indices_result =
        this->Constant(indices_all);
    TRT_ENSURE_PTR_OK(indices_result);
    nvinfer1::ITensor* gather_indices = (*indices_result)->getOutput(0);
    TRT_ENSURE(gather_indices->getDimensions().nbDims == 1);
    TRT_ENSURE(gather_indices->getType() == nvinfer1::DataType::kINT32);

    // Append a zero to the shape tensor.
    StatusOr<nvinfer1::IConstantLayer*> zero_result =
        this->Constant(std::vector<int>{0});
    TRT_ENSURE_PTR_OK(zero_result);
    std::array<nvinfer1::ITensor*, 2> cat_inputs = {
        runtime_shape, (*zero_result)->getOutput(0)};
    nvinfer1::IConcatenationLayer* cat_layer =
        network_->addConcatenation(cat_inputs.data(), cat_inputs.size());
    TRT_ENSURE(cat_layer);
    nvinfer1::ITensor* gather_input = cat_layer->getOutput(0);
    TRT_ENSURE(gather_input);

    // Finally, gather the indices from the extended shape tensor.
    nvinfer1::IGatherLayer* gather =
        network_->addGather(*gather_input, *gather_indices, 0);
    TRT_ENSURE(gather);
    return gather;
  }
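
  // For example (illustrative only): given an input with runtime shape
  // [2, 3, 5] and indices = {0, 2}, GetPartialShapeOf produces the shape
  // tensor [2, 0, 5]; with sub_one = true it produces [1, 0, 4].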

  // Adds a Scale layer that uniformly scales the input tensor by the
  // specified amount.
  StatusOr<nvinfer1::IScaleLayer*> AddUniformScale(nvinfer1::ITensor* input,
                                                   float scale,
                                                   const std::string& name) {
    TRT_ENSURE(input);
    TRT_ENSURE(!name.empty());
    StatusOr<nvinfer1::Weights> weight = this->ScalarWeights<float>(scale, 1);
    TRT_ENSURE_OK(weight);
    const nvinfer1::Weights empty_weights =
        nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
    nvinfer1::IScaleLayer* scale_layer =
        network_->addScale(*input, nvinfer1::ScaleMode::kUNIFORM,
                           empty_weights, (*weight), empty_weights);
    TRT_ENSURE(scale_layer != nullptr);
    scale_layer->setName(name.c_str());
    TRT_ENSURE(scale_layer->getPower().count == 0);
    TRT_ENSURE(scale_layer->getShift().count == 0);
    TRT_ENSURE(scale_layer->getScale().count == 1);
    return scale_layer;
  }

  // Adds a Fill layer in kLINSPACE mode that produces a tensor of shape
  // "trt_dims", starting at the scalar given by "value_input" (or
  // "scalar_tensor") and stepping by "delta" (or "beta_tensor") along each
  // dimension. When the value or dims inputs are not static, they are wired
  // in as dynamic layer inputs.
  StatusOr<nvinfer1::ILayer*> AddFill(const TRT_TensorOrWeights& value_input,
                                      const TRT_TensorOrWeights& dims_input,
                                      bool is_value_static,
                                      bool is_dims_static, int nbDims,
                                      const nvinfer1::Dims& trt_dims,
                                      ITensorProxyPtr scalar_tensor = nullptr,
                                      ITensorProxyPtr beta_tensor = nullptr,
                                      const float delta = 0) {
    // TensorRT IFillLayer requires a rank 0 scalar.
    nvinfer1::Dims scalar_dims;
    scalar_dims.nbDims = 0;
    if (is_value_static) {
      StatusOr<nvinfer1::IConstantLayer*> const_layer = WeightsToConstant(
          value_input.weights().GetTrtWeights(), scalar_dims);
      if (!const_layer.status().ok()) return const_layer.status();
      scalar_tensor = (*const_layer)->getOutput(0);
    } else {
      if (scalar_tensor == nullptr) {
        StatusOr<nvinfer1::IShuffleLayer*> shuffler_layer =
            Reshape(value_input.tensor()->trt_tensor(), scalar_dims);
        if (!shuffler_layer.status().ok()) return shuffler_layer.status();
        scalar_tensor = (*shuffler_layer)->getOutput(0);
      }
    }

    if (beta_tensor == nullptr) {
      nvinfer1::Dims beta_shape{1, {nbDims}};
      StatusOr<nvinfer1::IConstantLayer*> const_layer =
          Constant(delta, beta_shape, value_input.TrtDType());
      TF_RETURN_IF_ERROR(const_layer.status());
      beta_tensor = (*const_layer)->getOutput(0);
    }

    nvinfer1::IFillLayer* layer =
        network_->addFill(trt_dims, nvinfer1::FillOperation::kLINSPACE);
    TRT_ENSURE(layer);
    if (!is_dims_static) {
      layer->setInput(0, *dims_input.tensor()->trt_tensor());
    }
    layer->setInput(1, *scalar_tensor->trt_tensor());
    layer->setInput(2, *beta_tensor->trt_tensor());
    return layer;
  }
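
  // Fill semantics, for reference (illustrative only): with trt_dims = [4],
  // a scalar start value of 0, and delta = 2, kLINSPACE produces the tensor
  // {0, 2, 4, 6}.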

  // Adds a quantization layer that uniformly scales the input tensor
  // by the given multiplicative "scaling_factor", then rounds
  // (round-to-nearest-ties-to-even) to the nearest integer, and clamps to
  // the range [-128, 127].
  StatusOr<nvinfer1::ILayer*> Quantize(nvinfer1::ITensor* input,
                                       const float scaling_factor,
                                       const std::string& name) {
    TRT_ENSURE(input);
    TRT_ENSURE(!name.empty());
    // Preprocessor usage here is unavoidable because the TRT8 API is new.
#if IS_TRT_VERSION_GE(8, 0, 0, 0)
    // The TensorRT IQuantizeLayer divides by the scale factor rather than
    // multiplies. To be consistent, in this function we expect a
    // multiplicative scale factor, so we take the reciprocal.
    StatusOr<nvinfer1::IConstantLayer*> scaling_const =
        this->Constant<float>(1.0f / scaling_factor, 1);
    TRT_ENSURE_PTR_OK(scaling_const);
    (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}});
    nvinfer1::IQuantizeLayer* quant_layer =
        network_->addQuantize(*input, *(*scaling_const)->getOutput(0));
    TRT_ENSURE(quant_layer);
    quant_layer->setAxis(1);
    return quant_layer;
#else
    StatusOr<nvinfer1::IScaleLayer*> result =
        this->AddUniformScale(input, scaling_factor, name);
    TRT_ENSURE_PTR_OK(result);
    (*result)->setOutputType(0, nvinfer1::DataType::kINT8);
    (*result)->setPrecision(nvinfer1::DataType::kFLOAT);
    return result;
#endif
  }

  // Adds a dequantize layer that casts the input tensor to the TensorRT
  // float type and scales it uniformly by the given multiplicative
  // "scaling_factor".
  StatusOr<nvinfer1::ILayer*> Dequantize(nvinfer1::ITensor* input,
                                         const float scaling_factor,
                                         const std::string& name) {
    TRT_ENSURE(input);
    TRT_ENSURE(!name.empty());
#if IS_TRT_VERSION_GE(8, 0, 0, 0)
    StatusOr<nvinfer1::IConstantLayer*> scaling_const =
        this->Constant<float>(scaling_factor, 1);
    TRT_ENSURE_PTR_OK(scaling_const);
    (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}});
    nvinfer1::IDequantizeLayer* dequant_layer =
        network_->addDequantize(*input, *(*scaling_const)->getOutput(0));
    TRT_ENSURE(dequant_layer);
    dequant_layer->setAxis(1);
    return dequant_layer;
#else
    StatusOr<nvinfer1::IScaleLayer*> result =
        this->AddUniformScale(input, scaling_factor, name);
    TRT_ENSURE_PTR_OK(result);
    (*result)->setOutputType(0, nvinfer1::DataType::kFLOAT);
    (*result)->setPrecision(nvinfer1::DataType::kINT8);
    return result;
#endif
  }
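
  // Example Q/DQ round trip (a sketch, not part of the original header;
  // `builder` and a float tensor `input` are assumed):
  //
  //   StatusOr<nvinfer1::ILayer*> q =
  //       builder.Quantize(input, /*scaling_factor=*/127.0f, "q_example");
  //   TRT_ENSURE_PTR_OK(q);
  //   StatusOr<nvinfer1::ILayer*> dq = builder.Dequantize(
  //       (*q)->getOutput(0), /*scaling_factor=*/1.0f / 127.0f, "dq_example");
  //
  // Quantize multiplies by its scale factor and Dequantize multiplies by its
  // own, so reciprocal factors approximately reconstruct the input.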

  // Adds TensorRT Q/DQ operations. This is for explicit precision mode.
  StatusOr<nvinfer1::ILayer*> UniformQuantizeDequantizeExplicit(
      nvinfer1::ITensor* input, float quantize_scale, float dequantize_scale,
      const std::string& name) {
    TRT_ENSURE(input);
    if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
      TRT_ENSURE(network_->hasExplicitPrecision());
    }
    TRT_ENSURE(IS_TRT_VERSION_GE(7, 1, 0, 0));

    static int count = 0;
    TRT_ENSURE(input->getType() == nvinfer1::DataType::kFLOAT);
    std::string quant_name = absl::StrCat(input->getName(), "_quant_", count);

    StatusOr<nvinfer1::ILayer*> quant =
        this->Quantize(input, quantize_scale, quant_name);
    TRT_ENSURE_PTR_OK(quant);

    std::string dequant_name =
        absl::StrCat(input->getName(), "_dequant_", count);
    StatusOr<nvinfer1::ILayer*> dequant = this->Dequantize(
        (*quant)->getOutput(0), dequantize_scale, dequant_name);
    TRT_ENSURE_PTR_OK(dequant);

    count++;
    return dequant;
  }

  // Adds a Shuffle layer that reshapes the input tensor to "new_shape".
  StatusOr<nvinfer1::IShuffleLayer*> Reshape(nvinfer1::ITensor* input,
                                             const nvinfer1::Dims& new_shape) {
    TRT_ENSURE(input);
    nvinfer1::IShuffleLayer* layer = network_->addShuffle(*input);
    TRT_ENSURE(layer);
    layer->setReshapeDimensions(new_shape);
    return layer;
  }

  // Returns the layer in the network that produces "tensor" as an output,
  // or NotFound if no such layer exists.
  StatusOr<nvinfer1::ILayer*> FindProducerOf(const nvinfer1::ITensor* tensor) {
    const char* name = tensor->getName();
    const int num_layers = network_->getNbLayers();
    for (int i = 0; i < num_layers; i++) {
      nvinfer1::ILayer* layer = network_->getLayer(i);
      const int num_outputs = layer->getNbOutputs();
      for (int j = 0; j < num_outputs; j++) {
        nvinfer1::ITensor* t = layer->getOutput(j);
        if (std::string(t->getName()) == name) {
          return layer;
        }
      }
    }
    return errors::NotFound("could not find producing layer of ", name);
  }

  // Returns the layer that produces the given input of "layer".
  StatusOr<nvinfer1::ILayer*> UniqueParentOf(const nvinfer1::ILayer* layer,
                                             int input_idx = 0) {
    return FindProducerOf(layer->getInput(input_idx));
  }

  nvinfer1::INetworkDefinition* Network() { return network_; }

 private:
  nvinfer1::INetworkDefinition* const network_;
  TrtWeightStore* const weight_store_;
};

// Builder for creating and configuring a TensorRT Shuffle layer through a
// fluent interface.
class ShuffleBuilder {
 private:
  explicit ShuffleBuilder(TRTNetworkBuilder* builder, nvinfer1::ITensor* input)
      : builder_(builder) {
    layer_ = builder->Network()->addShuffle(*input);
  }

 public:
  static StatusOr<ShuffleBuilder> Create(TRTNetworkBuilder* builder,
                                         nvinfer1::ITensor* input) {
    TRT_ENSURE(builder != nullptr);
    TRT_ENSURE(input != nullptr);
    return ShuffleBuilder(builder, input);
  }

  ShuffleBuilder& SetReshape(const nvinfer1::Dims& dims) {
    layer_->setReshapeDimensions(dims);
    return *this;
  }

  ShuffleBuilder& SetReshape(nvinfer1::ITensor* shape) {
    layer_->setInput(1, *shape);
    return *this;
  }

  ShuffleBuilder& SetFirstTranspose(const nvinfer1::Permutation& perm) {
    layer_->setFirstTranspose(perm);
    return *this;
  }

  ShuffleBuilder& SetSecondTranspose(const nvinfer1::Permutation& perm) {
    layer_->setSecondTranspose(perm);
    return *this;
  }

  StatusOr<nvinfer1::ITensor*> Output() {
    TRT_ENSURE(layer_ != nullptr);
    TRT_ENSURE(layer_->getOutput(0) != nullptr);
    return layer_->getOutput(0);
  }

 private:
  TRTNetworkBuilder* builder_;
  nvinfer1::IShuffleLayer* layer_;
};
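
// Example usage (a minimal sketch, not part of the original header;
// `builder` is a valid TRTNetworkBuilder* and `input` a valid ITensor*):
//
//   StatusOr<ShuffleBuilder> shuffle = ShuffleBuilder::Create(builder, input);
//   TRT_ENSURE_OK(shuffle);
//   StatusOr<nvinfer1::ITensor*> out =
//       shuffle->SetFirstTranspose({{1, 0}})
//           .SetReshape(nvinfer1::Dims{1, {-1}})
//           .Output();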

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_