/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "third_party/tensorrt/NvInferPlugin.h"

#endif

namespace tensorflow {
namespace tensorrt {

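// Returns the TensorRT version that this build was compiled against, taken
// from the NV_TENSORRT_* macros in the TensorRT headers.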
std::tuple<int, int, int> GetLinkedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
                                   NV_TENSORRT_PATCH};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

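// Returns the version of the TensorRT library that is actually loaded at
// runtime, which may differ from the linked (compile-time) version above.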
std::tuple<int, int, int> GetLoadedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
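  // getInferLibVersion() encodes the version as
  // major * 1000 + minor * 100 + patch (e.g. 7203 for TensorRT 7.2.3),
  // so we decode it with integer division and remainders below.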
  int ver = getInferLibVersion();
  int major = ver / 1000;
  ver = ver - major * 1000;
  int minor = ver / 100;
  int patch = ver - minor * 100;
  return std::tuple<int, int, int>{major, minor, patch};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

}  // namespace tensorrt
}  // namespace tensorflow

#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {

Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  tensorflow::profiler::TraceMe activity(
      "GetTrtBindingIndex", tensorflow::profiler::TraceMeLevel::kInfo);
  // If the engine has been built for K profiles, the first getNbBindings() / K
  // bindings are used by profile number 0, the following getNbBindings() / K
  // bindings are used by profile number 1, etc.
  //
  // getBindingIndex(tensor_name) returns the binding index for profile 0.
  // We can also consider it as a "binding_index_within_profile".
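  //
  // For example, an engine built for 2 profiles with 8 bindings in total has
  // 4 bindings per profile: bindings 0-3 belong to profile 0 and bindings 4-7
  // belong to profile 1.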
  *binding_index = cuda_engine->getBindingIndex(tensor_name);
  if (*binding_index == -1) {
    const string msg = absl::StrCat("Input node ", tensor_name, " not found");
    return errors::NotFound(msg);
  }
  int n_profiles = cuda_engine->getNbOptimizationProfiles();
  // If we have more than one optimization profile, then we need to shift the
  // binding index according to the following formula:
  //   binding_index_within_engine = binding_index_within_profile +
  //                                 profile_index * bindings_per_profile
  const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles;
  *binding_index = *binding_index + profile_index * bindings_per_profile;
  return Status::OK();
}
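
// Example (hypothetical tensor name and profile index): look up the
// engine-wide binding index of "input_0" under optimization profile 1:
//
//   int binding_index = -1;
//   TF_RETURN_IF_ERROR(GetTrtBindingIndex("input_0", /*profile_index=*/1,
//                                         engine, &binding_index));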

Status GetTrtBindingIndex(int network_input_index, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  const string input_name =
      absl::StrCat(IONamePrefixes::kInputPHName, network_input_index);
  return GetTrtBindingIndex(input_name.c_str(), profile_index, cuda_engine,
                            binding_index);
}

namespace {

void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
#if defined(PLATFORM_WINDOWS)
  LOG_WARNING_WITH_PREFIX
      << "Windows support is provided experimentally. No guarantee is made "
         "regarding functionality or engineering support. Use at your own "
         "risk.";
#endif
  LOG(INFO) << "Linked TensorRT version: "
            << absl::StrJoin(GetLinkedTensorRTVersion(), ".");
  LOG(INFO) << "Loaded TensorRT version: "
            << absl::StrJoin(GetLoadedTensorRTVersion(), ".");

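  // Register TensorRT's built-in plugins with the global plugin registry.
  // The empty string registers them under the default plugin namespace.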
  bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
  if (!plugin_initialized) {
    LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
                  "fail later.";
  }

  int num_trt_plugins = 0;
  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
  if (!trt_plugin_creator_list) {
    LOG_WARNING_WITH_PREFIX << "Cannot find any TensorRT plugins in registry.";
  } else {
    VLOG(1) << "Found the following " << num_trt_plugins
            << " TensorRT plugins in registry:";
    for (int i = 0; i < num_trt_plugins; ++i) {
      if (!trt_plugin_creator_list[i]) {
        LOG_WARNING_WITH_PREFIX
            << "TensorRT plugin at index " << i
            << " is not accessible (null pointer returned by "
               "getPluginCreatorList for this plugin)";
      } else {
        VLOG(1) << "  " << trt_plugin_creator_list[i]->getPluginName();
      }
    }
  }
}

}  // namespace

void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
  static absl::once_flag once;
  absl::call_once(once, InitializeTrtPlugins, trt_logger);
}

}  // namespace tensorrt
}  // namespace tensorflow

namespace nvinfer1 {
std::ostream& operator<<(std::ostream& os,
                         const nvinfer1::TensorFormat& format) {
  os << "nvinfer1::TensorFormat::";
  switch (format) {
    case nvinfer1::TensorFormat::kLINEAR:
      os << "kLINEAR";
      break;

    case nvinfer1::TensorFormat::kCHW2:
      os << "kCHW2";
      break;

    case nvinfer1::TensorFormat::kHWC8:
      os << "kHWC8";
      break;

    case nvinfer1::TensorFormat::kCHW4:
      os << "kCHW4";
      break;

    case nvinfer1::TensorFormat::kCHW16:
      os << "kCHW16";
      break;

    case nvinfer1::TensorFormat::kCHW32:
      os << "kCHW32";
      break;

#if IS_TRT_VERSION_GE(8, 0, 0, 0)
    case nvinfer1::TensorFormat::kDHWC8:
      os << "kDHWC8";
      break;

    case nvinfer1::TensorFormat::kCDHW32:
      os << "kCDHW32";
      break;

    case nvinfer1::TensorFormat::kHWC:
      os << "kHWC";
      break;

    case nvinfer1::TensorFormat::kDLA_LINEAR:
      os << "kDLA_LINEAR";
      break;

    case nvinfer1::TensorFormat::kDLA_HWC4:
      os << "kDLA_HWC4";
      break;

    case nvinfer1::TensorFormat::kHWC16:
      os << "kHWC16";
      break;
#endif

    default:
      os << "unknown format";
  }
  return os;
}

std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& v) {
  os << "nvinfer1::DataType::";
  switch (v) {
    case nvinfer1::DataType::kFLOAT:
      os << "kFLOAT";
      break;
    case nvinfer1::DataType::kHALF:
      os << "kHALF";
      break;
    case nvinfer1::DataType::kINT8:
      os << "kINT8";
      break;
    case nvinfer1::DataType::kINT32:
      os << "kINT32";
      break;
    case nvinfer1::DataType::kBOOL:
      os << "kBOOL";
      break;
  }
  return os;
}
}  // namespace nvinfer1

#endif