/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "absl/strings/ascii.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {
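
// The DebugString overloads below produce human-readable descriptions of
// TensorRT and TensorFlow objects for use in logging and error messages.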
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < std::max(dims.nbDims, 0); ++i) {
    StrAppend(&out, dims.d[i]);
    StrAppend(&out, ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    case DT_BOOL:
      return "DT_BOOL";
    default:
      return "Unknown TF DataType";
  }
}
string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    case nvinfer1::DataType::kBOOL:
      return "kBOOL";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}
string DebugString(const ITensorProxyPtr& tensor) {
  return StrCat(
      tensor->is_trt_tensor() ? "nvinfer1::ITensor(@" : "SimpleITensor(@",
      reinterpret_cast<uintptr_t>(&tensor), ", name=", tensor->getName(),
      ", dtype=", DebugString(tensor->getType()),
      ", dims=", DebugString(tensor->getDimensions()), ")");
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}
string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}
// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode (in explicit batch mode one needs to
// check the profile ranges instead), so implicit batch mode is assumed.
// It is also assumed that both actual_shapes and cached_shapes have been
// verified by TRTEngineOp::VerifyInputShapes, which ensures that the batch
// size is the same for all tensors.
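//
// Illustrative example (hypothetical shapes): with cached_shapes =
// {[8, 28, 28, 3]} built for max batch size 8, actual_shapes =
// {[4, 28, 28, 3]} is compatible (smaller batch, identical remaining dims),
// whereas {[16, 28, 28, 3]} (batch too large) and {[4, 14, 28, 3]}
// (mismatched non-batch dimension) are not.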
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}
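
// Fills input_shapes with the partial shape of each input tensor of the given
// TensorRT network definition.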
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    TF_RETURN_IF_ERROR(DimsAdapter(input->getDimensions())
                           .PartialTensorShape(&input_shapes->at(i)));
  }
  return Status::OK();
}
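
// Maps a TensorFlow DataType to the corresponding TensorRT data type; returns
// an InvalidArgument error for types without a TensorRT equivalent.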
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    case DT_BOOL:
      *trt_type = nvinfer1::DataType::kBOOL;
      break;
#endif
    default:
      return errors::InvalidArgument("Unsupported tensorflow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}
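
// Maps a TensorRT data type back to the corresponding TensorFlow DataType;
// returns an InvalidArgument error for unrecognized types.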
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    case nvinfer1::DataType::kBOOL:
      *tf_type = DT_BOOL;
      break;
#endif
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to the TensorRT 7 doc: "If the engine has been built for K
  // profiles, the first getNbBindings() / K bindings are used by profile
  // number 0, the following getNbBindings() / K bindings are used by profile
  // number 1 etc." Therefore, to get the number of input tensors, we need to
  // divide by the number of profiles.
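  //
  // For example (illustrative numbers): an engine with 2 inputs and 1 output
  // built for 3 optimization profiles exposes 9 bindings, 6 of which are
  // inputs, so the number of distinct input tensors is 6 / 3 == 2.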
  int n_profiles = engine->getNbOptimizationProfiles();
  return n_input / n_profiles;
}
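
// Returns the device the node has been assigned to, falling back to the
// requested device if no assignment has been made yet.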
absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}
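
// Parses the node's device name into its components; returns std::nullopt if
// the name cannot be parsed.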
std::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return std::nullopt;
  }
  return parsed_name;
}
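
// Merges two device names if they are compatible (soft placement disallowed);
// returns std::nullopt otherwise.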
std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return std::nullopt;
  }
  return merged_name;
}

std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return std::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT