/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H

#include <cmath>    // std::abs, used by SimpleITensor::getDynamicRange.
#include <cstddef>
#include <cstdint>  // std::uintptr_t, used by ITensorProxyHash.
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/core/platform/logging.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {

namespace tensorrt {

// SimpleITensor implements part of the ITensor interfaces to support the TF-TRT
// validator, as well as some TF-TRT tests. The former use case only utilizes
// the interfaces related to shape and type information.
class SimpleITensor {
 public:
  SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
      : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}

  SimpleITensor() : dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}
  SimpleITensor(const nvinfer1::Dims& dims)
      : trt_dims_(dims), dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {}

  SimpleITensor(const std::vector<int>& dims) {
    trt_dims_.nbDims = dims.size();
    for (int i = 0; i < dims.size(); ++i) {
      trt_dims_.d[i] = dims[i];
    }
    dynamic_range_min_ = 0.0f;
    dynamic_range_max_ = 0.0f;
  }

  void setName(const char* name) {}

  const char* getName() const { return ""; }

  void setDimensions(nvinfer1::Dims dimensions) { trt_dims_ = dimensions; }

  nvinfer1::Dims getDimensions() const { return trt_dims_; }

  void setType(nvinfer1::DataType trt_dtype) { trt_dtype_ = trt_dtype; }

  nvinfer1::DataType getType() const { return trt_dtype_; }

  bool isNetworkInput() const { return false; }

  bool isNetworkOutput() const { return false; }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {}

  bool getBroadcastAcrossBatch() const { return false; }

  nvinfer1::TensorLocation getLocation() const { return location_; }

  void setLocation(nvinfer1::TensorLocation location) { location_ = location; }
  bool setDynamicRange(float min, float max) {
    dynamic_range_max_ = max;
    dynamic_range_min_ = min;
    return true;
  }

  float getDynamicRange() const {
    return (std::abs(dynamic_range_min_) + dynamic_range_max_) / 2.f;
  }
  bool dynamicRangeIsSet() const { return true; }

  void resetDynamicRange() {
    dynamic_range_min_ = 0.f;
    dynamic_range_max_ = 0.f;
  }
  float getDynamicRangeMin() const { return dynamic_range_min_; }

  float getDynamicRangeMax() const { return dynamic_range_max_; }

  void setAllowedFormats(nvinfer1::TensorFormats formats) {}

  nvinfer1::TensorFormats getAllowedFormats() const { return 1; }

  bool isShapeTensor() const { return false; }
  bool isExecutionTensor() const { return true; }

 private:
  nvinfer1::DataType trt_dtype_;
  nvinfer1::Dims trt_dims_;
  std::string name_;
  nvinfer1::TensorLocation location_;
  float dynamic_range_min_;
  float dynamic_range_max_;
};
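
// Illustrative sketch only (the variable names below are hypothetical, not
// part of this header): the validator and tests typically build a
// SimpleITensor just to carry shape/type metadata, e.g.
//
//   nvinfer1::Dims dims;
//   dims.nbDims = 2;
//   dims.d[0] = 1;
//   dims.d[1] = 8;
//   SimpleITensor tensor(nvinfer1::DataType::kFLOAT, dims);
//   CHECK(tensor.getType() == nvinfer1::DataType::kFLOAT);
//   CHECK_EQ(tensor.getDimensions().nbDims, 2);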

enum class TensorType : int { kTRT, kSIMPLE };

class ITensorProxy {
 public:
  //! ITensor not owned
  ITensorProxy(nvinfer1::ITensor* trt_tensor)
      : trt_tensor_(trt_tensor), ttype_(TensorType::kTRT) {}

  //! SimpleITensor owned
  ITensorProxy(SimpleITensor* simple_itensor)
      : simple_tensor_(simple_itensor), ttype_(TensorType::kSIMPLE) {}

  //! SimpleITensor owned
  explicit ITensorProxy(nvinfer1::DataType trt_dtype,
                        const nvinfer1::Dims& trt_dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(
            new SimpleITensor(trt_dtype, trt_dims))),
        ttype_(TensorType::kSIMPLE) {}

  //! Variants for testing purposes
  ITensorProxy()
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor())),
        ttype_(TensorType::kSIMPLE) {}

  explicit ITensorProxy(const nvinfer1::Dims& dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
        ttype_(TensorType::kSIMPLE) {}

  explicit ITensorProxy(const std::vector<int>& dims)
      : simple_tensor_(std::unique_ptr<SimpleITensor>(new SimpleITensor(dims))),
        ttype_(TensorType::kSIMPLE) {}

  bool is_trt_tensor() const {
    CHECK(validate());
    return trt_tensor_ != nullptr;
  }

  bool is_simple_tensor() const {
    CHECK(validate());
    return simple_tensor_ != nullptr;
  }

  TensorType ttype() const { return ttype_; }

  nvinfer1::ITensor* trt_tensor() const {
    CHECK_NOTNULL(trt_tensor_);
    CHECK(ttype_ == TensorType::kTRT);
    return trt_tensor_;
  }

  SimpleITensor* simple_tensor() const {
    CHECK_NOTNULL(simple_tensor_);
    CHECK(ttype_ == TensorType::kSIMPLE);
    return simple_tensor_.get();
  }

  void setName(const char* name) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setName(name);
      case TensorType::kSIMPLE:
        return simple_tensor_->setName(name);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  const char* getName() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getName();
      case TensorType::kSIMPLE:
        return simple_tensor_->getName();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setDimensions(nvinfer1::Dims dimensions) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setDimensions(dimensions);
      case TensorType::kSIMPLE:
        return simple_tensor_->setDimensions(dimensions);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::Dims getDimensions() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDimensions();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDimensions();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setType(nvinfer1::DataType type) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setType(type);
      case TensorType::kSIMPLE:
        return simple_tensor_->setType(type);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::DataType getType() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getType();
      case TensorType::kSIMPLE:
        return simple_tensor_->getType();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isNetworkInput() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isNetworkInput();
      case TensorType::kSIMPLE:
        return simple_tensor_->isNetworkInput();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isNetworkOutput() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isNetworkOutput();
      case TensorType::kSIMPLE:
        return simple_tensor_->isNetworkOutput();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
      case TensorType::kSIMPLE:
        return simple_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool getBroadcastAcrossBatch() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getBroadcastAcrossBatch();
      case TensorType::kSIMPLE:
        return simple_tensor_->getBroadcastAcrossBatch();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::TensorLocation getLocation() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getLocation();
      case TensorType::kSIMPLE:
        return simple_tensor_->getLocation();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void setLocation(nvinfer1::TensorLocation location) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setLocation(location);
      case TensorType::kSIMPLE:
        return simple_tensor_->setLocation(location);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool setDynamicRange(float min, float max) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setDynamicRange(min, max);
      case TensorType::kSIMPLE:
        return simple_tensor_->setDynamicRange(min, max);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool dynamicRangeIsSet() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->dynamicRangeIsSet();
      case TensorType::kSIMPLE:
        return simple_tensor_->dynamicRangeIsSet();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  void resetDynamicRange() {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->resetDynamicRange();
      case TensorType::kSIMPLE:
        return simple_tensor_->resetDynamicRange();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }
  float getDynamicRangeMin() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRangeMin();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRangeMin();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  float getDynamicRangeMax() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRangeMax();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRangeMax();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }
#if !IS_TRT_VERSION_GE(8, 0, 0, 0)
  float getDynamicRange() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getDynamicRange();
      case TensorType::kSIMPLE:
        return simple_tensor_->getDynamicRange();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }
#endif
  void setAllowedFormats(nvinfer1::TensorFormats formats) {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->setAllowedFormats(formats);
      case TensorType::kSIMPLE:
        return simple_tensor_->setAllowedFormats(formats);
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  nvinfer1::TensorFormats getAllowedFormats() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->getAllowedFormats();
      case TensorType::kSIMPLE:
        return simple_tensor_->getAllowedFormats();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isShapeTensor() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isShapeTensor();
      case TensorType::kSIMPLE:
        return simple_tensor_->isShapeTensor();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

  bool isExecutionTensor() const {
    switch (ttype_) {
      case TensorType::kTRT:
        return trt_tensor_->isExecutionTensor();
      case TensorType::kSIMPLE:
        return simple_tensor_->isExecutionTensor();
    }
    LOG(FATAL) << "Unsupported itensor_ type";
  }

 private:
  bool validate() const {
    return (trt_tensor_ && !simple_tensor_) || (!trt_tensor_ && simple_tensor_);
  }

  // When ITensorProxy represents an ITensor, the ITensor can either be passed
  // by the caller via the constructor that takes an ITensor* as parameter, or
  // be created as a SimpleITensor.
  //
  // In the first case, the ITensor pointer is stored in 'trt_tensor_' below,
  // and the ITensor itself is not owned by this class. This path is used by
  // Converter (e.g. AddInputTensor) and op converters during TRT network
  // construction, where the TRT network owns the ITensor.
  //
  nvinfer1::ITensor* trt_tensor_ = nullptr;  // Not owned.
  // In the second case, the created SimpleITensor is stored in
  // 'simple_tensor_' below and is owned by this class. SimpleITensor is a fake
  // implementation of ITensor and is used for testing and by TrtNodeValidator
  // to validate the graph nodes.
  std::shared_ptr<SimpleITensor> simple_tensor_ = nullptr;

  TensorType ttype_;
};
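
// Illustrative sketch only (the 'network', 'dtype', and 'dims' values below
// are hypothetical, not part of this header): the two ownership modes are
// exercised roughly as follows.
//
//   // Wrap an ITensor owned by the TRT network (the proxy does not own it):
//   nvinfer1::ITensor* input = network->addInput("x", dtype, dims);
//   ITensorProxy from_trt(input);
//   CHECK(from_trt.is_trt_tensor());
//
//   // Validator/tests: the proxy creates and owns a SimpleITensor:
//   ITensorProxy from_simple(nvinfer1::DataType::kFLOAT, dims);
//   CHECK(from_simple.is_simple_tensor());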

class ITensorProxyPtr {
 public:
  ITensorProxyPtr(std::nullptr_t) : p_(nullptr) {}
  ITensorProxyPtr(ITensorProxy* p) : p_(p) {}
  ITensorProxyPtr(nvinfer1::ITensor* p) : p_(new ITensorProxy(p)) {}
  ITensorProxyPtr(SimpleITensor* p) : p_(new ITensorProxy(p)) {}

  ITensorProxyPtr() : p_(new ITensorProxy()) {}
  ITensorProxyPtr(const nvinfer1::Dims& dims) : p_(new ITensorProxy(dims)) {}
  ITensorProxyPtr(const std::vector<int>& dims) : p_(new ITensorProxy(dims)) {}

  std::shared_ptr<ITensorProxy> p_{nullptr};
  ITensorProxy* operator->() { return p_.get(); }
  ITensorProxy* operator->() const { return p_.get(); }
  ITensorProxy* operator*() { return p_.get(); }
  ITensorProxy* operator*() const { return p_.get(); }
};
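
// Minimal usage sketch (hypothetical variable names): ITensorProxyPtr acts as
// a shared, pointer-like handle over an ITensorProxy, so handles can be copied
// cheaply and forwarded through the converter code.
//
//   ITensorProxyPtr a(std::vector<int>{1, 2, 3});  // Owns a SimpleITensor.
//   ITensorProxyPtr b = a;                         // Shares the same proxy.
//   nvinfer1::Dims dims = a->getDimensions();
//   CHECK(a == b);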

inline bool operator==(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) {
  if (p1.p_ == nullptr) {
    return p2.p_ == nullptr;
  }
  if (p2.p_ == nullptr) {
    return p1.p_ == nullptr;
  }
  return (p1->ttype() == p2->ttype()) &&
         ((p1->ttype() == TensorType::kTRT &&
           p1->trt_tensor() == p2->trt_tensor()) ||
          (p1->ttype() == TensorType::kSIMPLE &&
           p1->simple_tensor() == p2->simple_tensor()));
}

inline bool operator!=(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) {
  return !(p1 == p2);
}

struct ITensorProxyHash {
  size_t operator()(const ITensorProxyPtr& tensor) const {
    return reinterpret_cast<std::uintptr_t>(tensor.p_.get());
  }
};
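
// Sketch of the intended use (the container choice here is an assumption, not
// mandated by this header): ITensorProxyHash hashes by the address of the
// wrapped ITensorProxy object, which lets ITensorProxyPtr serve as a key in
// hashed containers.
//
//   std::unordered_set<ITensorProxyPtr, ITensorProxyHash> seen;
//   seen.insert(tensor_ptr);  // 'tensor_ptr' is a hypothetical ITensorProxyPtr.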

}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_TENSOR_PROXY_H