/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/core/platform/logging.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {

namespace tensorrt {

// SimpleITensor implements part of the ITensor interfaces to support the
// TF-TRT validator, as well as some TF-TRT tests. The former use case only
// utilizes the interfaces related to shape and type information.
class SimpleITensor {
 public:
  SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
      : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}

  SimpleITensor() : dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}

  SimpleITensor(const nvinfer1::Dims& dims)
      : trt_dims_(dims), dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}

  SimpleITensor(const std::vector<int>& dims) {
    trt_dims_.nbDims = dims.size();
    for (int i = 0; i < dims.size(); ++i) {
      trt_dims_.d[i] = dims[i];
    }
    dynamic_range_min_ = 0.0f;
    dynamic_range_max_ = 0.0f;
  }

  void setName(const char* name) {}

  const char* getName() const { return ""; }

  void setDimensions(nvinfer1::Dims dimensions) { trt_dims_ = dimensions; }

  nvinfer1::Dims getDimensions() const { return trt_dims_; }

  void setType(nvinfer1::DataType trt_dtype) { trt_dtype_ = trt_dtype; }

  nvinfer1::DataType getType() const { return trt_dtype_; }

  bool isNetworkInput() const { return false; }

  bool isNetworkOutput() const { return false; }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {}

  bool getBroadcastAcrossBatch() const { return false; }

  nvinfer1::TensorLocation getLocation() const { return location_; }

  void setLocation(nvinfer1::TensorLocation location) { location_ = location; }

  bool setDynamicRange(float min, float max) {
    dynamic_range_max_ = max;
    dynamic_range_min_ = min;
    return true;
  }

  float getDynamicRange() const {
    return (std::abs(dynamic_range_min_) + dynamic_range_max_) / 2.f;
  }

  bool dynamicRangeIsSet() const { return true; }

  void resetDynamicRange() {
    dynamic_range_min_ = 0.f;
    dynamic_range_max_ = 0.f;
  }

  float getDynamicRangeMin() const { return dynamic_range_min_; }

  float getDynamicRangeMax() const { return dynamic_range_max_; }

  void setAllowedFormats(nvinfer1::TensorFormats formats) {}

  nvinfer1::TensorFormats getAllowedFormats() const { return 1; }

  bool isShapeTensor() const { return false; }

  bool isExecutionTensor() const { return true; }

 private:
  nvinfer1::DataType trt_dtype_;
  nvinfer1::Dims trt_dims_;
  std::string name_;
  nvinfer1::TensorLocation location_;
  float dynamic_range_min_;
  float dynamic_range_max_;
};
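// A minimal usage sketch (illustrative only, not part of the API): the
// validator path typically constructs a SimpleITensor from a data type and
// dimensions and then only queries shape/type information. The values below
// are hypothetical.
//
//   nvinfer1::Dims dims;
//   dims.nbDims = 2;
//   dims.d[0] = 3;
//   dims.d[1] = 4;
//   SimpleITensor tensor(nvinfer1::DataType::kFLOAT, dims);
//   nvinfer1::Dims queried = tensor.getDimensions();  // nbDims == 2
//   nvinfer1::DataType dtype = tensor.getType();      // kFLOAT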
enum class TensorType : int { kTRT, kSIMPLE };

class ITensorProxy {
 public:
  //! ITensor not owned
  ITensorProxy(nvinfer1::ITensor* trt_tensor)
      : trt_tensor_(trt_tensor), ttype_(TensorType::kTRT) {}

  //! SimpleITensor owned
  ITensorProxy(SimpleITensor* simple_itensor)
      : simple_tensor_(simple_itensor), ttype_(TensorType::kSIMPLE) {}

  //! SimpleITensor owned
  explicit ITensorProxy(nvinfer1::DataType trt_dtype,
                        const nvinfer1::Dims& trt_dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(
            new SimpleITensor(trt_dtype, trt_dims))),
        ttype_(TensorType::kSIMPLE) {}

  //! Variants for testing purposes
  ITensorProxy()
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor())),
        ttype_(TensorType::kSIMPLE) {}

  explicit ITensorProxy(const nvinfer1::Dims& dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
        ttype_(TensorType::kSIMPLE) {}

  explicit ITensorProxy(const std::vector<int>& dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
        ttype_(TensorType::kSIMPLE) {}

  bool is_trt_tensor() const {
    CHECK(validate());
    return trt_tensor_ != nullptr;
  }

  bool is_simple_tensor() const {
    CHECK(validate());
    return simple_tensor_ != nullptr;
  }

  TensorType ttype() const { return ttype_; }

  nvinfer1::ITensor* trt_tensor() const {
    CHECK_NOTNULL(trt_tensor_);
    CHECK(ttype_ == TensorType::kTRT);
    return trt_tensor_;
  }

  SimpleITensor* simple_tensor() const {
    CHECK_NOTNULL(simple_tensor_);
    CHECK(ttype_ == TensorType::kSIMPLE);
    return simple_tensor_.get();
  }

  void setName(const char* name) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setName(name);
      case TensorType::kSIMPLE:
        return simple_tensor_->setName(name);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  const char* getName() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getName();
      case TensorType::kSIMPLE:
        return simple_tensor_->getName();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setDimensions(nvinfer1::Dims dimensions) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setDimensions(dimensions);
      case TensorType::kSIMPLE:
        return simple_tensor_->setDimensions(dimensions);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::Dims getDimensions() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDimensions();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDimensions();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setType(nvinfer1::DataType type) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setType(type);
      case TensorType::kSIMPLE:
        return simple_tensor_->setType(type);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::DataType getType() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getType();
      case TensorType::kSIMPLE:
        return simple_tensor_->getType();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isNetworkInput() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isNetworkInput();
      case TensorType::kSIMPLE:
        return simple_tensor_->isNetworkInput();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isNetworkOutput() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isNetworkOutput();
      case TensorType::kSIMPLE:
        return simple_tensor_->isNetworkOutput();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
      case TensorType::kSIMPLE:
        return simple_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool getBroadcastAcrossBatch() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getBroadcastAcrossBatch();
      case TensorType::kSIMPLE:
        return simple_tensor_->getBroadcastAcrossBatch();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::TensorLocation getLocation() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getLocation();
      case TensorType::kSIMPLE:
        return simple_tensor_->getLocation();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setLocation(nvinfer1::TensorLocation location) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setLocation(location);
      case TensorType::kSIMPLE:
        return simple_tensor_->setLocation(location);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool setDynamicRange(float min, float max) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setDynamicRange(min, max);
      case TensorType::kSIMPLE:
        return simple_tensor_->setDynamicRange(min, max);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool dynamicRangeIsSet() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->dynamicRangeIsSet();
      case TensorType::kSIMPLE:
        return simple_tensor_->dynamicRangeIsSet();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void resetDynamicRange() {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->resetDynamicRange();
      case TensorType::kSIMPLE:
        return simple_tensor_->resetDynamicRange();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  float getDynamicRangeMin() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRangeMin();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRangeMin();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  float getDynamicRangeMax() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRangeMax();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRangeMax();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

#if !IS_TRT_VERSION_GE(8, 0, 0, 0)
  float getDynamicRange() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRange();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRange();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }
#endif

  void setAllowedFormats(nvinfer1::TensorFormats formats) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setAllowedFormats(formats);
      case TensorType::kSIMPLE:
        return simple_tensor_->setAllowedFormats(formats);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::TensorFormats getAllowedFormats() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getAllowedFormats();
      case TensorType::kSIMPLE:
        return simple_tensor_->getAllowedFormats();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isShapeTensor() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isShapeTensor();
      case TensorType::kSIMPLE:
        return simple_tensor_->isShapeTensor();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isExecutionTensor() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isExecutionTensor();
      case TensorType::kSIMPLE:
        return simple_tensor_->isExecutionTensor();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

 private:
  bool validate() const {
    return (trt_tensor_ && !simple_tensor_) || (!trt_tensor_ && simple_tensor_);
  }

  // When ITensorProxy represents an ITensor, the ITensor can either be passed
  // in by the caller via the constructor that takes an ITensor* as parameter,
  // or be created internally as a SimpleITensor.
  //
  // In the first case, the ITensor pointer is stored in 'trt_tensor_' below,
  // and the ITensor itself is not owned by this class. This case is used by
  // Converter (e.g. AddInputTensor) and op converters during TRT network
  // construction, where the TRT network owns the ITensor.
  nvinfer1::ITensor* trt_tensor_ = nullptr;  // Not owned.

  // In the second case, the created SimpleITensor is stored in
  // 'simple_tensor_' below and is owned by this class. SimpleITensor is a
  // fake implementation of ITensor and is used for testing and by
  // TrtNodeValidator to validate the graph nodes.
  std::shared_ptr<SimpleITensor> simple_tensor_ = nullptr;

  TensorType ttype_;
};
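// A minimal usage sketch (illustrative only): an ITensorProxy either borrows
// an ITensor owned by the TRT network being built, or owns a SimpleITensor
// for validation and testing. 'network' and 'dims' below are hypothetical
// (e.g. a nvinfer1::INetworkDefinition* with at least one input, and a
// nvinfer1::Dims value).
//
//   ITensorProxy from_network(network->getInput(0));      // ITensor not owned.
//   CHECK(from_network.is_trt_tensor());
//
//   ITensorProxy fake(nvinfer1::DataType::kFLOAT, dims);  // SimpleITensor owned.
//   CHECK(fake.is_simple_tensor());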
class ITensorProxyPtr {
 public:
  ITensorProxyPtr(std::nullptr_t) : p_(nullptr) {}
  ITensorProxyPtr(ITensorProxy* p) : p_(p) {}
  ITensorProxyPtr(nvinfer1::ITensor* p) : p_(new ITensorProxy(p)) {}
  ITensorProxyPtr(SimpleITensor* p) : p_(new ITensorProxy(p)) {}

  ITensorProxyPtr() : p_(new ITensorProxy()) {}
  ITensorProxyPtr(const nvinfer1::Dims& dims) : p_(new ITensorProxy(dims)) {}
  ITensorProxyPtr(const std::vector<int>& dims) : p_(new ITensorProxy(dims)) {}

  std::shared_ptr<ITensorProxy> p_{nullptr};
  ITensorProxy* operator->() { return p_.get(); }
  ITensorProxy* operator->() const { return p_.get(); }
  ITensorProxy* operator*() { return p_.get(); }
  ITensorProxy* operator*() const { return p_.get(); }
};
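// A minimal usage sketch (illustrative only): ITensorProxyPtr is a copyable
// handle with shared ownership of the underlying ITensorProxy, so it can be
// passed around and stored by value. 'trt_tensor' below is a hypothetical
// nvinfer1::ITensor* obtained from the network under construction.
//
//   ITensorProxyPtr wrapped(trt_tensor);              // Wraps an ITensor*.
//   ITensorProxyPtr fake(std::vector<int>{1, 2, 3});  // Owns a SimpleITensor.
//   nvinfer1::Dims dims = fake->getDimensions();      // Forwarded via proxy.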
inline bool operator==(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) {
  if (p1.p_ == nullptr) {
    return p2.p_ == nullptr;
  }
  if (p2.p_ == nullptr) {
    return p1.p_ == nullptr;
  }
  return (p1->ttype() == p2->ttype()) &&
         ((p1->ttype() == TensorType::kTRT &&
           p1->trt_tensor() == p2->trt_tensor()) ||
          (p1->ttype() == TensorType::kSIMPLE &&
           p1->simple_tensor() == p2->simple_tensor()));
}

inline bool operator!=(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) {
  return !(p1 == p2);
}

struct ITensorProxyHash {
  size_t operator()(const ITensorProxyPtr& tensor) const {
    return reinterpret_cast<std::uintptr_t>(tensor.p_.get());
  }
};
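// Note on container usage (illustrative only): ITensorProxyHash hashes the
// address of the shared ITensorProxy object, while operator== above compares
// the wrapped tensors, so copies of the same ITensorProxyPtr hash and compare
// identically. A hypothetical set of already-seen tensors could look like:
//
//   std::unordered_set<ITensorProxyPtr, ITensorProxyHash> seen;
//   seen.insert(tensor_ptr);  // 'tensor_ptr' is an ITensorProxyPtr.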
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H