// tensorflow/compiler/tf2tensorrt/convert/utils.cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "absl/strings/ascii.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

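// Formats a TensorRT dimension struct for logging; for example, a Dims with
// nbDims=2 and d={1, 3} is printed as "nvinfer1::Dims(nbDims=2, d=1,3,)".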
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < std::max(dims.nbDims, 0); ++i) {
    StrAppend(&out, dims.d[i]);
    StrAppend(&out, ",");
  }
  StrAppend(&out, ")");
  return out;
}

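// Returns the name of the given TF DataType; only the types handled by the
// TF-TRT bridge are spelled out.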
string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    case DT_BOOL:
      return "DT_BOOL";
    default:
      return "Unknown TF DataType";
  }
}

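// Returns the name of the given TensorRT DataType enumerator, e.g. "kFLOAT".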
string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    case nvinfer1::DataType::kBOOL:
      return "kBOOL";
    default:
      return "Invalid TRT data type";
  }
}

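// Formats the first `len` entries of a TensorRT transpose permutation.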
string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

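// Describes a tensor proxy by address, name, dtype and dims. The prefix tells
// real TRT tensors apart from SimpleITensor placeholders.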
string DebugString(const ITensorProxyPtr& tensor) {
  return StrCat(
      tensor->is_trt_tensor() ? "nvinfer1::ITensor(@" : "SimpleITensor(@",
      reinterpret_cast<uintptr_t>(&tensor), ", name=", tensor->getName(),
      ", dtype=", DebugString(tensor->getType()),
      ", dims=", DebugString(tensor->getDimensions()), ")");
}

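// Same as above, but for a raw nvinfer1::ITensor.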
string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

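// Formats a list of dimension structs as "[Dims1,Dims2,...]".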
string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

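// The next two overloads delegate shape-list formatting to the corresponding
// TensorShape utilities.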
string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode (in explicit batch mode one needs to
// check the profile ranges). Therefore implicit batch mode is assumed.
// It is also assumed that both actual_shapes and cached_shapes have been
// verified by TRTEngineOp::VerifyInputShapes, which ensures that the batch
// size is the same for all tensors.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}
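
// Extracts the (possibly dynamic) shape of each network input into
// *input_shapes, converting nvinfer1::Dims to PartialTensorShape.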
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    TF_RETURN_IF_ERROR(DimsAdapter(input->getDimensions())
                           .PartialTensorShape(&input_shapes->at(i)));
  }
  return Status::OK();
}

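// Maps a TF DataType to the corresponding TensorRT type; returns an
// InvalidArgument error for types the bridge cannot convert.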
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    case DT_BOOL:
      *trt_type = nvinfer1::DataType::kBOOL;
      break;
#endif
    default:
      return errors::InvalidArgument("Unsupported TensorFlow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}

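// Inverse of TfTypeToTrtType: maps a TensorRT type back to a TF DataType.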
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    case nvinfer1::DataType::kBOOL:
      *tf_type = DT_BOOL;
      break;
#endif
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

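// Returns the number of input tensors of the network from which the engine
// was built.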
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to the TensorRT 7 doc: "If the engine has been built for K
  // profiles, the first getNbBindings() / K bindings are used by profile
  // number 0, the following getNbBindings() / K bindings are used by profile
  // number 1 etc." Therefore, to get the number of input tensors, we need to
  // divide by the number of profiles.
  int n_profiles = engine->getNbOptimizationProfiles();
  return n_input / n_profiles;
}

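// Returns the device assigned to the node if one exists, otherwise the
// requested device.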
absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}

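// Parses the node's device string; returns std::nullopt if it is malformed.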
std::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return std::nullopt;
  }
  return parsed_name;
}

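// Merges two device specifications; returns std::nullopt if they conflict
// (soft placement is not allowed).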
std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return std::nullopt;
  }
  return merged_name;
}

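// Same as above, but parses `b` from its string form first.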
std::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return std::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT