/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <algorithm>
#include <iterator>
#include <memory>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/env_var.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

#define TFTRT_ERROR(func, ...)                                              \
  do {                                                                      \
    return func("TFTRT::", __FUNCTION__, ":", __LINE__, ": ", __VA_ARGS__); \
  } while (0)

#define TFTRT_CHECK_SHAPE_TENSOR(tensor)                                 \
  if (!IsTrtShapeTensorCompatible(tensor)) {                             \
    TFTRT_ERROR(errors::InvalidArgument, "Tensor of type ",              \
                DebugString(tensor.dtype()), " having shape ",           \
                tensor.shape().DebugString(), " is not TRT compatible"); \
  }
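
// Example usage (a sketch; `ValidateShapeInput` and `input` are hypothetical
// names used only to illustrate the macros, which must appear in a function
// returning Status):
//
//   Status ValidateShapeInput(const Tensor& input) {
//     TFTRT_CHECK_SHAPE_TENSOR(input);
//     return Status::OK();
//   }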

namespace tensorflow {
namespace tensorrt {

static constexpr char kCastOutputTypeAttrName[] = "DstT";

#if !IS_TRT_VERSION_GE(8, 2, 0, 0)
template <typename T>
struct TrtDestroyer {
  void operator()(T* t) {
    if (t) t->destroy();
  }
};
template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
#else
template <typename T>
using TrtUniquePtrType = std::unique_ptr<T>;
#endif
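
// Example usage (a sketch; `logger` is assumed to be an existing
// nvinfer1::ILogger implementation owned by the caller):
//
//   TrtUniquePtrType<nvinfer1::IBuilder> builder(
//       nvinfer1::createInferBuilder(logger));
//
// The builder is released correctly whether or not the explicit destroy()
// call is required by the linked TensorRT version.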

// Define a hash function for vector<TensorShape> because it is used as the key
// for the engine cache.
struct VectorTensorShapeHasher {
  std::size_t operator()(const std::vector<TensorShape>& key) const {
    return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
  }
};
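
// Example usage (a sketch; the mapped type is shown as a placeholder `int`,
// not the actual engine cache value type):
//
//   std::unordered_map<std::vector<TensorShape>, int, VectorTensorShapeHasher>
//       engine_cache;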

using absl::StrAppend;
using absl::StrCat;

// This utility template converts an arithmetic type to a string. This function
// is necessary to allow the following function to behave recursively:
// `string DebugString(const std::vector<CType>&)`.
template <typename CType, typename = typename std::enable_if<
                              std::is_arithmetic<CType>::value, CType>::type>
string DebugString(const CType& el) {
  string el_str = std::to_string(el);
  // Prettify std::to_string, which can sometimes return 1.50000 instead of
  // 1.5. In short, this removes trailing 0s from a string-formatted number.
  el_str.erase(el_str.find_last_not_of('0') + 1, std::string::npos);
  return el_str;
}
// This utility template converts nested vectors to a string for debug
// purposes.
template <typename CType>
string DebugString(const std::vector<CType>& vector) {
  string tmp_s = "";
  for (const auto el : vector) {
    StrAppend(&tmp_s, StrCat(DebugString(el), ", "));
  }
  return StrCat("{", tmp_s.substr(0, tmp_s.length() - 2), "}");
}
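// Example for the vector overload above (a sketch):
// DebugString(std::vector<float>{1.5f, 2.25f}) produces "{1.5, 2.25}".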
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const DataType tf_type);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const ITensorProxyPtr& tensor);
string DebugString(const nvinfer1::ITensor& tensor);
string DebugString(const std::vector<nvinfer1::Dims>& dimvec);
string DebugString(const std::vector<TensorShape>& shapes);
string DebugString(const std::vector<PartialTensorShape>& shapes);

template <size_t N>
string DebugString(const absl::InlinedVector<int64, N>& data) {
  return absl::StrCat("[", absl::StrJoin(data, ","), "]");
}

inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  for (int d = 0; d < dims.nbDims; ++d) {
    if (dims.d[d] < 0) return false;
  }
  return true;
}

template <typename T>
bool HasStaticShape(const T& dims) {
  return !absl::c_any_of(dims, [](int i) { return i < 0; });
}
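
// Example (a sketch): an nvinfer1::Dims with nbDims == 3 and d == {1, -1, 224}
// is not static because of the dynamic (-1) dimension, so HasStaticShape
// returns false; with d == {1, 3, 224} it returns true.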

// Returns whether a shape is compatible with a TRT shape tensor.
template <typename TensorShapeType>
inline bool IsTrtShapeTensorCompatible(const TensorShapeType& shape) {
  return (
      shape.dims() == 0 ||
      (shape.dims() == 1 && shape.num_elements() <= nvinfer1::Dims::MAX_DIMS));
}

// Returns whether a TF tensor could be interpreted as a TRT shape tensor.
inline bool IsTrtShapeTensorCompatible(const Tensor& tensor) {
  return tensor.dtype() == DT_INT32 &&
         IsTrtShapeTensorCompatible(tensor.shape());
}
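
// Example (a sketch): a DT_INT32 tensor of shape [4] holding dimension values
// such as {1, 3, 224, 224} is shape-tensor compatible; a DT_INT64 tensor or a
// rank-2 tensor is not.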

// Adapts various representations of shape (TF Shape, TRT Dims, plain
// containers) and provides methods for properties (length, volume) and
// conversion between types. Note that unlike TF's TensorShape, the underlying
// storage will only contain active dimensions. In the case of scalar shapes,
// `NumDims` is allowed to return 0 or 1, but the `storage_` vector will contain
// 1 element in both cases. In the non-scalar case, `NumDims() ==
// storage_.size()`.
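//
// Example usage (a sketch; `tensor_shape` is assumed to be a TensorShape
// supplied by the caller, inside a function that returns Status):
//
//   StatusOr<DimsAdapter> adapter = DimsAdapter::Create(tensor_shape);
//   TF_RETURN_IF_ERROR(adapter.status());
//   nvinfer1::Dims trt_dims = adapter->AsTrtDims();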
class DimsAdapter {
 public:
  using StorageType = absl::InlinedVector<int64_t, 4>;

 private:
  template <typename T>
  using EnableIfNotTensorShapeType =
      std::enable_if_t<!std::is_base_of<TensorShapeBase<T>, T>::value>;

  template <typename T>
  using EnableIfInt = std::enable_if_t<std::is_arithmetic<T>::value &&
                                       std::is_integral<T>::value>;

 public:
  //----- Constructors ------

  // Constructs from an absl::Span.
  template <typename T>
  explicit DimsAdapter(absl::Span<T> shape)
      : num_dims_(static_cast<int32_t>(shape.size())) {
    absl::c_copy(shape, std::back_inserter(storage_));
  }

  // Constructs from a std::vector.
  template <typename T>
  explicit DimsAdapter(const std::vector<T>& shape)
      : num_dims_(static_cast<int32_t>(shape.size())) {
    absl::c_copy(shape, std::back_inserter(storage_));
  }

  // Constructs from a TRT dims object.
  DimsAdapter(const nvinfer1::Dims& dims) : num_dims_(dims.nbDims) {
    absl::c_copy(absl::MakeSpan(dims.d, dims.d + std::max(dims.nbDims, 0)),
                 std::back_inserter(storage_));
  }

  // Constructs explicitly, specifying num_dims and storage data.
  DimsAdapter(int32_t num_dims, StorageType data)
      : num_dims_(num_dims), storage_(std::forward<StorageType>(data)) {}

  // Constructs from a TensorShape or PartialTensorShape.
  template <typename T>
  static StatusOr<DimsAdapter> Create(const TensorShapeBase<T>& shape,
                                      bool ignore_first_dim = false) {
    if (shape.dims() > nvinfer1::Dims::MAX_DIMS)
      return errors::InvalidArgument("dims of TensorShape exceed MAX_DIMS");
    if (ignore_first_dim && shape.dims() <= 0)
      return errors::InvalidArgument(
          "removing first dim requires explicit batch dimension");
    if (shape.dims() == -1) {
      return DimsAdapter(-1, StorageType{});
    }
    if (shape.dims() == 0) {
      return DimsAdapter(0, StorageType{1});
    }
    auto offt = (ignore_first_dim ? 1 : 0);
    return DimsAdapter(
        absl::MakeSpan(shape.dim_sizes().begin() + offt, shape.dims() - offt));
  }

  // Constructs from a container.
  template <typename InputSequence,
            typename = EnableIfNotTensorShapeType<InputSequence>>
  static StatusOr<DimsAdapter> Create(const InputSequence& shape,
                                      bool ignore_first_dim = false) {
    if (ignore_first_dim && shape.size() <= 0) {
      return errors::InvalidArgument(
          "removing first dim requires explicit batch dimension");
    }
    return DimsAdapter(
        absl::MakeSpan(shape).subspan(ignore_first_dim ? 1 : 0, shape.size()));
  }
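
  // Example (a sketch): for a TensorShape of [8, 3, 224, 224],
  // Create(shape, /*ignore_first_dim=*/true) yields a DimsAdapter describing
  // [3, 224, 224], i.e. the shape with the leading batch dimension removed.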

  //----- Conversion Utilities ------

  // Converts to an nvinfer1::Dims and assigns the result to the object passed
  // in via the result pointer.
  void TrtDims(nvinfer1::Dims* result) const {
    result->nbDims = num_dims_;
    absl::c_copy(storage_, static_cast<int32_t*>(result->d));
  }

  // Converts to an nvinfer1::Dims and returns it by value.
  nvinfer1::Dims AsTrtDims() const {
    nvinfer1::Dims result;
    TrtDims(&result);
    return result;
  }

  // Converts to a TensorShape and assigns the result to the object passed in
  // via the shape pointer.
  Status TensorShape(TensorShape* shape,
                     std::optional<int> batch_size = std::nullopt) const {
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        static_cast<const int64_t*>(storage_.data()), storage_.size(), shape));
    if (batch_size) shape->InsertDim(0, *batch_size);
    return Status::OK();
  }

  // Converts to a PartialTensorShape and assigns the result to the object
  // passed in via the shape pointer.
  Status PartialTensorShape(
      PartialTensorShape* shape,
      std::optional<int> batch_size = std::nullopt) const {
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
        static_cast<const int64_t*>(storage_.data()), storage_.size(), shape));
    if (batch_size) shape->InsertDim(0, *batch_size);
    return Status::OK();
  }
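
  // Example (a sketch): converting a DimsAdapter that describes [3, 224, 224]
  // with TensorShape(&shape, /*batch_size=*/8) produces a TensorShape of
  // [8, 3, 224, 224].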

  // Copies the dimension values to the vector passed in via the shape pointer.
  template <typename T, typename = EnableIfInt<T>>
  Status Vector(std::vector<T>* shape) const {
    shape->clear();
    absl::c_copy(storage_, std::back_inserter(*shape));
    return Status::OK();
  }

  //----- Property Accessors ------

  // Returns true if the shape has no dynamic dimensions.
  bool IsStatic() const {
    return !absl::c_any_of(storage_, [](auto i) { return i < 0; });
  }

  // Returns product of all dimensions.
  int64_t Volume() const {
    return absl::c_accumulate(storage_, static_cast<int64_t>(1),
                              std::multiplies<>());
  }

  int32_t NumDims() const { return num_dims_; }

  // Returns true if the shape should be interpreted as a scalar. This follows
  // TensorRT conventions: a scalar shape can have NumDims()==1 or
  // NumDims()==0, but the underlying storage_ container has a single dimension
  // of size 1.
  bool IsScalar() const {
    return (num_dims_ == 0 || num_dims_ == 1) && storage_.size() == 1 &&
           storage_[0] == 1;
  }

  // Returns true if the dimension storage is empty. This indicates an empty
  // shape in both the scalar and non-scalar case.
  bool IsEmpty() const { return storage_.empty(); }

  string DebugString() const {
    auto vol = absl::c_accumulate(storage_, static_cast<int64_t>(1),
                                  std::multiplies<>());
    return absl::StrCat("DimsAdapter(num_dims=", num_dims_, ",shape=[",
                        absl::StrJoin(storage_, ","), "],", "vol=", vol, ")");
  }

  // Returns the beginning iterator for the underlying storage.
  StorageType::const_iterator begin() const { return storage_.begin(); }

  // Returns the ending iterator for the underlying storage.
  StorageType::const_iterator end() const { return storage_.end(); }

  // Returns the size of the dimension at `idx`.
  StorageType::value_type dim(size_t idx) const { return storage_[idx]; }

  // Returns a reference to the dimension at `idx`.
  StorageType::value_type& dim(size_t idx) { return storage_[idx]; }

  //----- Non-Const Operators ------

  DimsAdapter& Append(int32_t dim) {
    StatusOr<bool> is_scalar = IsScalar();
    if (!is_scalar.ok()) return *this;
    num_dims_ = *is_scalar ? 2 : num_dims_ + 1;
    storage_.push_back(dim);
    return *this;
  }

  DimsAdapter& Prepend(std::optional<int32_t> dim) {
    if (dim) {
      num_dims_ = IsScalar() ? 2 : num_dims_ + 1;
      storage_.insert(storage_.begin(), *dim);
    }
    return *this;
  }

  Status RemoveBatchDimension() {
    if (storage_.empty())
      return errors::InvalidArgument(
          "attempted to remove batch dim from scalar");
    num_dims_ -= 1;
    storage_.erase(storage_.begin());
    return Status::OK();
  }
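
  // Example (a sketch): Prepend(8) on a DimsAdapter describing [3, 224, 224]
  // yields [8, 3, 224, 224]; a subsequent RemoveBatchDimension() restores
  // [3, 224, 224].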

  //----- Comparison Operators ------

  bool operator==(const DimsAdapter& rhs) const {
    if (rhs.num_dims_ != num_dims_) return false;
    for (int i = 0; i < num_dims_; i++) {
      if (rhs.storage_[i] != storage_[i]) return false;
    }
    return true;
  }

  bool operator!=(const DimsAdapter& rhs) const { return !(*this == rhs); }

 private:
  int32_t num_dims_{0};
  StorageType storage_{};
};

Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes);

Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
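
// Example (a sketch): TfTypeToTrtType(DT_FLOAT, &trt_type) sets trt_type to
// nvinfer1::DataType::kFLOAT and returns an OK status; types with no TensorRT
// equivalent produce an error status instead.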

// Returns true if an engine built for cached_shapes can also run actual_shapes.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes);
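
// Example (a sketch): a cached shape of [-1, 224, 224, 3] is compatible with
// an actual shape of [8, 224, 224, 3], since a -1 (dynamic) cached dimension
// matches any actual size.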

// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the total number of input bindings equals the number
// of profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);
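
// Example (a sketch): an engine built with 2 optimization profiles over a
// network with 3 inputs exposes 6 input bindings, but GetNumberOfEngineInputs
// returns 3.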

// Returns the string representation for the assigned device or the requested
// device of the given node.
absl::string_view GetDeviceName(const Node* node);

// Returns the ParsedName representation for the assigned device or the
// requested device string of the given node. If the device string is invalid,
// returns std::nullopt.
std::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node);

// If the given two device assignments are compatible, returns the merge of the
// two assignments. Otherwise, returns std::nullopt.
std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b);

// Similar to the above, except that the second device assignment is
// represented by a string_view.
std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b);
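
// Example (a sketch): merging the parsed form of "/device:GPU:0" with the
// string "/job:worker/device:GPU:0" yields a ParsedName carrying both the job
// and the GPU assignment; conflicting device ids yield std::nullopt.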

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_