/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_

#include <numeric>
#include <tuple>

#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace tensorrt {
// Returns the compile-time TensorRT library version as
// {major, minor, patch}.
std::tuple<int, int, int> GetLinkedTensorRTVersion();

// Returns the runtime (loaded) TensorRT library version as
// {major, minor, patch}.
std::tuple<int, int, int> GetLoadedTensorRTVersion();
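
// A minimal usage sketch (illustrative, not part of this header): comparing
// the linked and loaded versions to detect a build/runtime mismatch.
//
//   const auto linked = GetLinkedTensorRTVersion();
//   const auto loaded = GetLoadedTensorRTVersion();
//   if (std::get<0>(linked) != std::get<0>(loaded)) {
//     LOG(WARNING) << "TensorRT major version mismatch.";
//   }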
}  // namespace tensorrt
}  // namespace tensorflow

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "third_party/tensorrt/NvInfer.h"

#define TFTRT_INTERNAL_ERROR_AT_NODE(node)                           \
  do {                                                               \
    return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__,  \
                            " failed to add TRT layer, at: ", node); \
  } while (0)

#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
  do {                                           \
    if (ptr == nullptr) {                        \
      TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
    }                                            \
  } while (0)

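// Illustrative sketch of typical use inside a converter function returning
// Status; `ConvertExample`, `layer`, and `node_name` are hypothetical names,
// not part of this header.
//
//   Status ConvertExample(nvinfer1::INetworkDefinition* network,
//                         const std::string& node_name) {
//     nvinfer1::ILayer* layer = network->addActivation(/*...*/);
//     TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name);
//     return Status::OK();
//   }
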
// Use this macro within functions that return a Status or StatusOr<T> to check
// boolean conditions. If the condition fails, it returns an
// errors::Internal message with the file and line number.
#define TRT_ENSURE(x)                                                        \
  if (!(x)) {                                                                \
    return errors::Internal(__FILE__, ":", __LINE__, " TRT_ENSURE failure"); \
  }

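// Illustrative sketch (hypothetical function): guarding a precondition in a
// function that returns Status.
//
//   Status CheckRank(const nvinfer1::Dims& dims) {
//     TRT_ENSURE(dims.nbDims > 0);
//     return Status::OK();
//   }
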
// Checks that a Status or StatusOr<T> object does not carry an error.
// If it does, returns an errors::Internal instance
// containing the error message, along with the file and line number. For
// pointer-containing StatusOr<T*>, use the TRT_ENSURE_PTR_OK macro below.
#define TRT_ENSURE_OK(x)                                   \
  if (!x.ok()) {                                           \
    return errors::Internal(__FILE__, ":", __LINE__,       \
                            " TRT_ENSURE_OK failure:\n  ", \
                            x.status().ToString());        \
  }

// Checks that a StatusOr<T*> object does not carry an error, and that the
// contained T* is non-null. If it does have an error status, returns an
// errors::Internal instance containing the error message, along with the file
// and line number.
#define TRT_ENSURE_PTR_OK(x)                            \
  TRT_ENSURE_OK(x);                                     \
  if (*x == nullptr) {                                  \
    return errors::Internal(__FILE__, ":", __LINE__,    \
                            " pointer had null value"); \
  }

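// Illustrative sketch; `GetTensor` is a hypothetical helper returning
// StatusOr<nvinfer1::ITensor*>.
//
//   Status UseTensor() {
//     StatusOr<nvinfer1::ITensor*> tensor = GetTensor();
//     TRT_ENSURE_PTR_OK(tensor);  // Fails on bad status or null pointer.
//     (*tensor)->setName("example");
//     return Status::OK();
//   }
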
namespace tensorflow {
namespace tensorrt {

#define IS_TRT_VERSION_GE(major, minor, patch, build)           \
  ((NV_TENSORRT_MAJOR > major) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH > patch) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))

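// Example: compile a block only when building against TensorRT 8.0.0.0 or
// newer.
//
//   #if IS_TRT_VERSION_GE(8, 0, 0, 0)
//   // Code that relies on TensorRT 8+ APIs.
//   #endif
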
#define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: "

// Initializes the TensorRT plugin registry if this hasn't been done yet.
void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger);

class IONamePrefixes {
 public:
  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
};

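// These prefixes are combined with an index to form placeholder tensor
// names, e.g. (illustrative):
//
//   std::string name = absl::StrCat(IONamePrefixes::kInputPHName, port);
//   // name == "TensorRTInputPH_0" when port == 0.
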
// Gets the binding index of a tensor in an engine.
//
// The binding index is looked up using the tensor's name and the profile
// index. Set the profile index to zero if the engine has no optimization
// profiles.
Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index);

// Gets the binding index of a tensor in an engine.
//
// Same as above, but uses the network input index to identify the tensor.
Status GetTrtBindingIndex(int network_input_idx, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index);
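
// Illustrative sketch (hypothetical `cuda_engine` variable): looking up the
// binding for the first engine input when no optimization profiles are used.
//
//   int binding_index = -1;
//   TF_RETURN_IF_ERROR(GetTrtBindingIndex("TensorRTInputPH_0",
//                                         /*profile_index=*/0, cuda_engine,
//                                         &binding_index));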
}  // namespace tensorrt
}  // namespace tensorflow

namespace nvinfer1 {
// Prints nvinfer1::Dims or any derived type to the given ostream. Per GTest
// printing requirements, this must be in the nvinfer1 namespace.
inline std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& v) {
  os << "nvinfer1::Dims[";
  os << absl::StrJoin(std::vector<int>(v.d, v.d + v.nbDims), ",");
  os << "]";
  return os;
}

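// Illustrative usage (hypothetical values):
//
//   nvinfer1::Dims dims;
//   dims.nbDims = 2;
//   dims.d[0] = 1;
//   dims.d[1] = 3;
//   LOG(INFO) << dims;  // Prints "nvinfer1::Dims[1,3]".
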
// Returns true if two nvinfer1::Dims structs (or derived types) are
// equivalent.
inline bool operator==(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
  if (rhs.nbDims != lhs.nbDims) {
    return false;
  }
  for (int i = 0; i < lhs.nbDims; i++) {
    if (rhs.d[i] != lhs.d[i]) {
      return false;
    }
  }
  return true;
}

// Returns true if two nvinfer1::Dims structs (or derived types) are not
// equivalent.
inline bool operator!=(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
  return !(rhs == lhs);
}

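// With these operators defined, GTest comparisons work directly on dimension
// structs, e.g. (illustrative):
//
//   EXPECT_EQ(actual_dims, expected_dims);
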
// Prints nvinfer1::INetworkDefinition* information to the given ostream.
inline std::ostream& operator<<(std::ostream& os,
                                nvinfer1::INetworkDefinition* n) {
  os << "nvinfer1::INetworkDefinition{\n";
  std::vector<int> layer_idxs(n->getNbLayers());
  std::iota(layer_idxs.begin(), layer_idxs.end(), 0);
  os << absl::StrJoin(layer_idxs, "\n ",
                      [n](std::string* out, const int layer_idx) {
                        out->append(n->getLayer(layer_idx)->getName());
                      });
  os << "}";
  return os;
}

// Prints the TensorFormat enum name to the stream.
std::ostream& operator<<(std::ostream& os,
                         const nvinfer1::TensorFormat& format);

// Prints the DataType enum name to the stream.
std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& data_type);

}  // namespace nvinfer1

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_