// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/common/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
17 
18 #if GOOGLE_CUDA && GOOGLE_TENSORRT
19 #include "absl/base/call_once.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/profiler/lib/traceme.h"
24 #include "third_party/tensorrt/NvInferPlugin.h"
25 
26 #endif
27 
28 namespace tensorflow {
29 namespace tensorrt {
30 
// Returns the TensorRT version this binary was compiled against, as a
// (major, minor, patch) tuple. When TensorRT support is compiled out
// (no GOOGLE_CUDA/GOOGLE_TENSORRT), returns (0, 0, 0).
std::tuple<int, int, int> GetLinkedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  return std::make_tuple(NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
                         NV_TENSORRT_PATCH);
#else
  return std::make_tuple(0, 0, 0);
#endif
}
39 
// Returns the version of the TensorRT library actually loaded at runtime,
// decoded from getInferLibVersion(), whose value packs the version as
// major * 1000 + minor * 100 + patch. When TensorRT support is compiled
// out, returns (0, 0, 0).
std::tuple<int, int, int> GetLoadedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  const int version = getInferLibVersion();
  const int major = version / 1000;
  const int minor = (version % 1000) / 100;
  const int patch = version % 100;
  return std::make_tuple(major, minor, patch);
#else
  return std::make_tuple(0, 0, 0);
#endif
}
52 
53 }  // namespace tensorrt
54 }  // namespace tensorflow
55 
56 #if GOOGLE_CUDA && GOOGLE_TENSORRT
57 namespace tensorflow {
58 namespace tensorrt {
59 
GetTrtBindingIndex(const char * tensor_name,int profile_index,const nvinfer1::ICudaEngine * cuda_engine,int * binding_index)60 Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
61                           const nvinfer1::ICudaEngine* cuda_engine,
62                           int* binding_index) {
63   tensorflow::profiler::TraceMe activity(
64       "GetTrtBindingIndex", tensorflow::profiler::TraceMeLevel::kInfo);
65   // If the engine has been built for K profiles, the first getNbBindings() / K
66   // bindings are used by profile number 0, the following getNbBindings() / K
67   // bindings are used by profile number 1 etc.
68   //
69   // GetBindingIndex(tensor_name) returns the binding index for the progile 0.
70   // We can also consider it as a "binding_index_within_profile".
71   *binding_index = cuda_engine->getBindingIndex(tensor_name);
72   if (*binding_index == -1) {
73     const string msg = absl::StrCat("Input node ", tensor_name, " not found");
74     return errors::NotFound(msg);
75   }
76   int n_profiles = cuda_engine->getNbOptimizationProfiles();
77   // If we have more then one optimization profile, then we need to shift the
78   // binding index according to the following formula:
79   // binding_index_within_engine = binding_index_within_profile +
80   //                               profile_index * bindings_per_profile
81   const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles;
82   *binding_index = *binding_index + profile_index * bindings_per_profile;
83   return Status::OK();
84 }
85 
GetTrtBindingIndex(int network_input_index,int profile_index,const nvinfer1::ICudaEngine * cuda_engine,int * binding_index)86 Status GetTrtBindingIndex(int network_input_index, int profile_index,
87                           const nvinfer1::ICudaEngine* cuda_engine,
88                           int* binding_index) {
89   const string input_name =
90       absl::StrCat(IONamePrefixes::kInputPHName, network_input_index);
91   return GetTrtBindingIndex(input_name.c_str(), profile_index, cuda_engine,
92                             binding_index);
93 }
94 
95 namespace {
96 
InitializeTrtPlugins(nvinfer1::ILogger * trt_logger)97 void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
98 #if defined(PLATFORM_WINDOWS)
99   LOG_WARNING_WITH_PREFIX
100       << "Windows support is provided experimentally. No guarantee is made "
101          "regarding functionality or engineering support. Use at your own "
102          "risk.";
103 #endif
104   LOG(INFO) << "Linked TensorRT version: "
105             << absl::StrJoin(GetLinkedTensorRTVersion(), ".");
106   LOG(INFO) << "Loaded TensorRT version: "
107             << absl::StrJoin(GetLoadedTensorRTVersion(), ".");
108 
109   bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
110   if (!plugin_initialized) {
111     LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
112                   "fail later.";
113   }
114 
115   int num_trt_plugins = 0;
116   nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
117       getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
118   if (!trt_plugin_creator_list) {
119     LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry.";
120   } else {
121     VLOG(1) << "Found the following " << num_trt_plugins
122             << " TensorRT plugins in registry:";
123     for (int i = 0; i < num_trt_plugins; ++i) {
124       if (!trt_plugin_creator_list[i]) {
125         LOG_WARNING_WITH_PREFIX
126             << "TensorRT plugin at index " << i
127             << " is not accessible (null pointer returned by "
128                "getPluginCreatorList for this plugin)";
129       } else {
130         VLOG(1) << "  " << trt_plugin_creator_list[i]->getPluginName();
131       }
132     }
133   }
134 }
135 
136 }  // namespace
137 
MaybeInitializeTrtPlugins(nvinfer1::ILogger * trt_logger)138 void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
139   static absl::once_flag once;
140   absl::call_once(once, InitializeTrtPlugins, trt_logger);
141 }
142 
143 }  // namespace tensorrt
144 }  // namespace tensorflow
145 
146 namespace nvinfer1 {
operator <<(std::ostream & os,const nvinfer1::TensorFormat & format)147 std::ostream& operator<<(std::ostream& os,
148                          const nvinfer1::TensorFormat& format) {
149   os << "nvinfer1::TensorFormat::";
150   switch (format) {
151     case nvinfer1::TensorFormat::kLINEAR:
152       os << "kLINEAR";
153       break;
154 
155     case nvinfer1::TensorFormat::kCHW2:
156       os << "kCHW2";
157       break;
158 
159     case nvinfer1::TensorFormat::kHWC8:
160       os << "kHWC8";
161       break;
162 
163     case nvinfer1::TensorFormat::kCHW4:
164       os << "kCHW4";
165       break;
166 
167     case nvinfer1::TensorFormat::kCHW16:
168       os << "kCHW16";
169       break;
170 
171     case nvinfer1::TensorFormat::kCHW32:
172       os << "kCHW32";
173       break;
174 
175 #if IS_TRT_VERSION_GE(8, 0, 0, 0)
176     case nvinfer1::TensorFormat::kDHWC8:
177       os << "kDHWC8";
178       break;
179 
180     case nvinfer1::TensorFormat::kCDHW32:
181       os << "kCDHW32";
182       break;
183 
184     case nvinfer1::TensorFormat::kHWC:
185       os << "kHWC";
186       break;
187 
188     case nvinfer1::TensorFormat::kDLA_LINEAR:
189       os << "kDLA_LINEAR";
190       break;
191 
192     case nvinfer1::TensorFormat::kDLA_HWC4:
193       os << "kDLA_HWC4";
194       break;
195 
196     case nvinfer1::TensorFormat::kHWC16:
197       os << "kHWC16";
198       break;
199 #endif
200 
201     default:
202       os << "unknown format";
203   }
204   return os;
205 }
206 
operator <<(std::ostream & os,const nvinfer1::DataType & v)207 std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& v) {
208   os << "nvinfer1::DataType::";
209   switch (v) {
210     case nvinfer1::DataType::kFLOAT:
211       os << "kFLOAT";
212       break;
213     case nvinfer1::DataType::kHALF:
214       os << "kHalf";
215       break;
216     case nvinfer1::DataType::kINT8:
217       os << "kINT8";
218       break;
219     case nvinfer1::DataType::kINT32:
220       os << "kINT32";
221       break;
222     case nvinfer1::DataType::kBOOL:
223       os << "kBOOL";
224       break;
225   }
226   return os;
227 }
228 }  // namespace nvinfer1
229 
230 #endif
231