/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"

#include <algorithm>
#include <functional>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "third_party/gpus/cuda/include/cuda_runtime_api.h"

namespace tensorflow {
namespace tensorrt {

// Returns a vector of nvinfer1::Dims for a vector of TensorShapes.
template <typename TensorShapeType>
std::vector<nvinfer1::Dims> GetDimVec(std::vector<TensorShapeType> shape_vec) {
  std::vector<nvinfer1::Dims> dimvec(shape_vec.size());
  absl::c_transform(shape_vec, dimvec.begin(), [](TensorShapeType shape) {
    auto adap = DimsAdapter::Create(shape);
    TF_CHECK_OK(adap.status());
    return adap->AsTrtDims();
  });
  return dimvec;
}
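
// A minimal usage sketch for GetDimVec (the shape values are hypothetical):
//
//   std::vector<TensorShape> shapes = {TensorShape({8, 3, 224, 224})};
//   std::vector<nvinfer1::Dims> dims = GetDimVec(shapes);
//   // dims[0].nbDims == 4 and dims[0].d[0..3] == {8, 3, 224, 224}.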

// In dynamic shape mode the optimization profile dims are only allowed to
// differ from the network input dims where the network input dims have -1
// values. We enforce this condition by changing prof_dims if necessary.
void EnforceCompatibility(nvinfer1::Dims* prof_dims,
                          const PartialTensorShape& input_shape) {
  for (int i = 0; i < input_shape.dims(); i++) {
    if (input_shape.dim_size(i) != -1) {
      prof_dims->d[i] = input_shape.dim_size(i);
    }
  }
}
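
// A minimal sketch of the effect (values are hypothetical):
//
//   PartialTensorShape input_shape({-1, 256});
//   nvinfer1::Dims prof_dims{2, {4, 300}};
//   EnforceCompatibility(&prof_dims, input_shape);
//   // prof_dims is now {2, {4, 256}}: only the wildcard (-1) dimension
//   // keeps its profiled value.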

void SetImplicitBatchModeCompatibleProfile(
    const std::vector<nvinfer1::Dims>& dimvec, std::vector<nvinfer1::Dims>* min,
    std::vector<nvinfer1::Dims>* opt, std::vector<nvinfer1::Dims>* max) {
  *min = dimvec;
  for (auto& dim : *min) {
    // Shape value tensors can have a -1 value as a wildcard; we do not change
    // the value in that case.
    if (dim.d[0] != -1) dim.d[0] = 1;  // Set min batch size to 1.
  }
  *opt = dimvec;
  *max = dimvec;
}
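
// Sketch with a hypothetical input shape:
//
//   std::vector<nvinfer1::Dims> dimvec = {{4, {8, 3, 224, 224}}};
//   std::vector<nvinfer1::Dims> min, opt, max;
//   SetImplicitBatchModeCompatibleProfile(dimvec, &min, &opt, &max);
//   // min[0] == {4, {1, 3, 224, 224}}  (batch dimension lowered to 1),
//   // opt[0] == max[0] == {4, {8, 3, 224, 224}}.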

void TrtShapeOptimizationProfile::ImplicitBatchModeCompatibleStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  for (auto& shape_vec : collected_shapes) {
    std::vector<nvinfer1::Dims> min, opt, max;
    SetImplicitBatchModeCompatibleProfile(shape_vec, &min, &opt, &max);
    VLOG(2) << "Initializing optimization profile config with min="
            << DebugString(min) << ", opt=max=" << DebugString(max);
    OptimizationProfileConfig profConfig{min, opt, max};
    profiles_.push_back(std::move(profConfig));
  }
}

// Applies a binary operation for each dimension of the input shapes.
// x[i].d[k] = op(x[i].d[k], y[i].d[k]), where i enumerates the input tensors,
// and k enumerates the dimensions of the tensors. The BinaryOperation may be
// std::min, std::max etc.
template <typename BinaryOperation>
Status ShapeProfileBinaryOp(std::vector<nvinfer1::Dims>* x,
                            const std::vector<nvinfer1::Dims>& y,
                            BinaryOperation op) {
  if (x->size() != y.size())
    return errors::InvalidArgument(
        "Number of input tensors differ during profile creation");
  for (int i = 0; i < x->size(); i++) {
    if (x->at(i).nbDims != y[i].nbDims)
      return errors::InvalidArgument(
          "Number of input dimensions differ during profile creation");
    for (int j = 0; j < x->at(i).nbDims; j++) {
      x->at(i).d[j] = op(x->at(i).d[j], y[i].d[j]);
    }
  }
  return Status::OK();
}
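
// Sketch with hypothetical dims: an elementwise minimum over two shape lists,
//
//   std::vector<nvinfer1::Dims> x = {{2, {4, 256}}};
//   std::vector<nvinfer1::Dims> y = {{2, {8, 128}}};
//   TF_CHECK_OK(ShapeProfileBinaryOp(
//       &x, y, [](int a, int b) { return std::min(a, b); }));
//   // x[0] == {2, {4, 128}}.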

Status TrtShapeOptimizationProfile::RangeStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  if (collected_shapes.empty()) return Status::OK();

  std::vector<nvinfer1::Dims> min = collected_shapes[0];
  std::vector<nvinfer1::Dims> max = min;

  for (int i = 1; i < collected_shapes.size(); i++) {
    TF_RETURN_IF_ERROR(
        ShapeProfileBinaryOp(&min, collected_shapes[i],
                             [](int a, int b) { return std::min(a, b); }));
    TF_RETURN_IF_ERROR(
        ShapeProfileBinaryOp(&max, collected_shapes[i],
                             [](int a, int b) { return std::max(a, b); }));
  }
  VLOG(2) << "Initializing optimization profile config with min="
          << DebugString(min) << ", opt=max=" << DebugString(max);
  OptimizationProfileConfig profConfig{min, max, max};
  profiles_.push_back(std::move(profConfig));
  return Status::OK();
}
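
// Sketch with hypothetical shapes: for collected_shapes =
// {{{2, {4, 128}}}, {{2, {8, 64}}}}, RangeStrategy creates a single profile
// with min = [4, 64] and opt = max = [8, 128], so that every shape seen so
// far falls inside the [min, max] range.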

void TrtShapeOptimizationProfile::OptimalStrategy(
    const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes) {
  for (auto& shape_vec : collected_shapes) {
    std::vector<nvinfer1::Dims> min = shape_vec;
    std::vector<nvinfer1::Dims> opt = min;
    std::vector<nvinfer1::Dims> max = min;
    VLOG(2) << "Initializing optimization profile config with min=opt=max="
            << DebugString(min);
    OptimizationProfileConfig profConfig{min, opt, max};
    profiles_.push_back(std::move(profConfig));
  }
}

// Collects the values of tensors that are shape tensor compatible. The values
// are stored in the actual_shape_values_ member variable.
Status TrtShapeOptimizationProfile::CollectShapeValues(OpKernelContext* ctx) {
  tensorflow::profiler::TraceMe activity(
      "TrtShapeOptimizationProfile::CollectShapeValues",
      tensorflow::profiler::TraceMeLevel::kInfo);
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));
  actual_shape_values_.resize(ctx->num_inputs());
  if (is_shape_tensor_.empty()) {
    is_shape_tensor_.resize(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); i++) {
      is_shape_tensor_[i] = IsTrtShapeTensorCompatible(ctx->input(i));
    }
  }
  int n_shape_val = 0;
  // First copy all the shape value candidates into the actual_shape_values_
  // vector.
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (is_shape_tensor_[i]) {
      if (ctx->input_dtype(i) != DT_INT32) {
        // If the is_shape_tensor mask was initialized from the input shapes
        // alone (without knowledge of the dtype), then we correct it here.
        is_shape_tensor_[i] = false;
        continue;
      }
      // We have to copy the shape values to the host, because TRT's
      // ExecutionContext::setInputShapeBinding expects a host pointer.
      n_shape_val++;
      const Tensor& input = ctx->input(i);
      actual_shape_values_[i].nbDims = input.NumElements();
      auto ret = cudaMemcpyAsync(
          actual_shape_values_[i].d, input.flat<int32>().data(),
          input.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost, *stream);
      if (ret != 0) {
        return errors::Internal("Could not copy shape tensor values");
      }
      VLOG(2) << "Input " << i << " is (probably) a shape tensor, n_values="
              << input.NumElements();
    } else {
      actual_shape_values_[i] = {0, {}};
    }
  }
  if (n_shape_val > 0) {
    // If we have any shape value candidates, wait until the data is copied to
    // the host.
    cudaStreamSynchronize(*stream);
  }
  return Status::OK();
}

// Collects the values of tensors that are shape tensor compatible. To be used
// for unit tests.
Status TrtShapeOptimizationProfile::CollectShapeValues(const DataVec& input) {
  actual_shape_values_.resize(input.size());
  for (int i = 0; i < input.size(); i++) {
    if (is_shape_tensor_[i]) {
      if (!IsTrtShapeTensorCompatible(input[i].tensor)) {
        return errors::Internal("Inconsistent shape tensor ", input[i].name,
                                ", ", i);
      }
      int n_elements = input[i].tensor.NumElements();
      actual_shape_values_[i].nbDims = n_elements;
      // During unit tests, the data is in unified memory.
      std::copy(input[i].tensor.flat<int32>().data(),
                input[i].tensor.flat<int32>().data() + n_elements,
                actual_shape_values_[i].d);
      VLOG(2) << "Collected tensor shape values "
              << DebugString(actual_shape_values_[i]);
    } else {
      actual_shape_values_[i] = {0, {}};
    }
  }
  return Status::OK();
}

// Adjusts the shape value profile to prevent TRT from removing shape value
// input bindings whose value is redundant (when only a single value matches
// the profile). This should be removed once NVIDIA bug 3153064 is fixed.
void FixShapeValueProfile(OptimizationProfileConfig* prof,
                          const std::vector<bool>& is_shape_tensor) {
  int shape_value_offset = is_shape_tensor.size();
  for (int i = 0; i < is_shape_tensor.size(); i++) {
    if (is_shape_tensor[i] &&
        std::equal(prof->min[shape_value_offset + i].d,
                   prof->min[shape_value_offset + i].d +
                       prof->min[shape_value_offset + i].nbDims,
                   prof->max[shape_value_offset + i].d)) {
      prof->max[shape_value_offset + i].d[0]++;
      VLOG(2) << "Adjusted profile for shape value tensor " << i << " "
              << DebugString(prof->max[shape_value_offset + i]);
    } else {
      VLOG(2) << "Not adjusting profile for tensor " << i
              << ", is_shape_tensor=" << is_shape_tensor[i];
    }
  }
}

// Checks whether rhs is already contained in values.
bool AlreadyCollected(const std::vector<std::vector<nvinfer1::Dims>>& values,
                      const std::vector<nvinfer1::Dims>& rhs) {
  for (auto& lhs : values) {
    bool ret = lhs.size() == rhs.size();
    for (int i = 0; ret && i < lhs.size(); i++) {
      ret &= lhs[i].nbDims == rhs[i].nbDims;
      for (int j = 0; ret && j < lhs[i].nbDims; j++) {
        ret &= (lhs[i].d[j] == rhs[i].d[j]);
      }
    }
    if (ret) return true;
  }
  return false;
}
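
// Sketch with hypothetical dims:
//
//   std::vector<std::vector<nvinfer1::Dims>> values = {{{2, {4, 128}}}};
//   AlreadyCollected(values, {{2, {4, 128}}});  // true: exact match.
//   AlreadyCollected(values, {{2, {4, 256}}});  // false: second dim differs.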

void TrtShapeOptimizationProfile::InitProfiles(
    const std::vector<PartialTensorShape>& input_partial_shapes,
    ProfileStrategy strategy) {
  strategy_ = strategy;
  if (input_shapes_.size() == 0) {
    VLOG(1) << "Not creating profiles without input_shapes. "
               "You have to enable profile generation mode first (build).";
    return;
  }
  // Preprocess the vector of input shapes and shape values:
  // - Converts TensorShape -> nvinfer1::Dims.
  // - Concatenates the shape values after the input shapes:
  //   dimvec = [dim0, dim1, ..., shapeval0, shapeval1, ...]
  // - Ensures that the list is unique.
  std::vector<std::vector<nvinfer1::Dims>> collected_shapes;
  for (int i = 0; i < input_shapes_.size(); i++) {
    auto shape_vec = input_shapes_[i];
    VLOG(2) << "InitProfiles, processing shape " << i;
    if (!shape_vec.empty()) {
      std::vector<nvinfer1::Dims> dimvec = GetDimVec(shape_vec);
      dimvec.insert(dimvec.end(), input_shape_values_[i].begin(),
                    input_shape_values_[i].end());
      // TODO(tfeher): This condition should not apply for explicit profiles.
      // In that case consecutive elements of collected_shapes contain the
      // user-defined values of min, opt and max, and it is valid to have
      // min = opt and opt = max.
      if (!AlreadyCollected(collected_shapes, dimvec)) {
        collected_shapes.push_back(dimvec);
      }
    }
  }
  switch (strategy_) {
    case ProfileStrategy::kImplicitBatchModeCompatible:
      VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy";
      ImplicitBatchModeCompatibleStrategy(collected_shapes);
      break;
    // Treat all other strategies the same as kOptimal for now. Implementing
    // those is outlined in the dynamic shape support implementation plan.
    case ProfileStrategy::kRange:
      VLOG(1) << "Creating profiles with Range strategy";
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kRangeOptimal:
      VLOG(1) << "Creating profiles with RangeOptimal strategy";
      OptimalStrategy(collected_shapes);
      TF_CHECK_OK(RangeStrategy(collected_shapes));
      break;
    case ProfileStrategy::kOptimal:
      VLOG(1) << "Creating profiles with Optimal strategy";
      OptimalStrategy(collected_shapes);
      break;
  }
  // Define a mask that describes which inputs could be shape tensors. Note
  // that we can have false positives here; the shape tensor mask will be
  // updated once the network is constructed.
  SetShapeTensorMask(input_partial_shapes);
  if (input_partial_shapes.size() > 0) {
    for (OptimizationProfileConfig& prof : profiles_) {
      // TODO: Remove this when the bug is fixed.
      FixShapeValueProfile(&prof, is_shape_tensor_);
      for (int i = 0; i < input_partial_shapes.size(); i++) {
        auto network_input = input_partial_shapes[i];
        EnforceCompatibility(&prof.min[i], network_input);
        EnforceCompatibility(&prof.opt[i], network_input);
        EnforceCompatibility(&prof.max[i], network_input);
      }
    }
  }
}
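
// A hypothetical calling sequence (normally driven by TRTEngineOp; AddShape
// is the shape-collection helper declared in the accompanying header):
//
//   TrtShapeOptimizationProfile profile;
//   profile.AddShape({TensorShape({4, 256})});  // Shapes recorded during
//   profile.AddShape({TensorShape({8, 256})});  // profile generation mode.
//   profile.InitProfiles({PartialTensorShape({-1, 256})},
//                        ProfileStrategy::kOptimal);
//   // Two kOptimal profiles are created, one per distinct collected shape.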

void TrtShapeOptimizationProfile::InitCalibProfile(
    const std::vector<TensorShape>& shapes) {
  VLOG(1) << "Collected shape(s) " << DebugString(shapes)
          << " for calibration profile.";
  auto shape_vec = shapes;
  if (!shape_vec.empty()) {
    std::vector<nvinfer1::Dims> dimvec = GetDimVec(shape_vec);
    dimvec.insert(dimvec.end(), actual_shape_values_.begin(),
                  actual_shape_values_.end());
    VLOG(2) << "Initializing calibration optimization profile config with "
            << "min=opt=max " << DebugString(dimvec);

    OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec};
    calib_profiles_ = std::move(profConfig);
  } else {
    VLOG(2) << "Failed to initialize calibration optimization profile.";
  }
}

Status TrtShapeOptimizationProfile::AddProfiles(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  // Create an optimization profile for calibration if necessary.
  if (!calib_profiles_.min.empty()) {
    VLOG(2) << "Setting up calibration profiles";
    auto* calibProfile = builder->createOptimizationProfile();
    Status status =
        calib_profiles_.SetDimensions(network, calibProfile, input_mask_);
    if (!status.ok()) {
      return status;
    }
    bool result = false;
    if (calibProfile->isValid()) {
      result = config->setCalibrationProfile(calibProfile);
    } else {
      VLOG(2) << "Calibration profile is not valid";
    }
    if (result) {
      VLOG(2) << "Added calibration optimization profile "
              << calib_profiles_.DebugString() << " to builder config.";
    } else {
      LOG(ERROR) << "Failed to add calibration optimization profile "
                 << calib_profiles_.DebugString()
                 << ". This usually happens when the profile is invalid.";
    }
  }
  // Create a vector of optimization profiles.
  for (int i = 0; i < profiles_.size(); i++) {
    auto* optProfile = builder->createOptimizationProfile();
    Status status =
        profiles_[i].SetDimensions(network, optProfile, input_mask_);
    if (!status.ok()) {
      return status;
    }
    int idx = -1;
    if (optProfile->isValid()) {
      idx = config->addOptimizationProfile(optProfile);
    }
    if (idx >= 0) {
      if (i != idx) {
        return errors::Internal(
            "Profile index of engine config is different from source profile "
            "index: ",
            i, " != ", idx);
      }
      VLOG(1) << "Added optimization profile " << profiles_[i].DebugString()
              << " with idx " << idx << " to builder config.";
    } else {
      LOG(ERROR) << "Failed to add optimization profile "
                 << profiles_[i].DebugString()
                 << ". This usually happens when the profile is invalid.";
    }
  }
  if (!profiles_.empty() && config->getNbOptimizationProfiles() == 0) {
    return errors::Internal("Failure in adding an optimization profile.");
  }
  need_profiles_ = config->getNbOptimizationProfiles() > 0;
  // Update the mask that flags shape tensors. Now that the network is known,
  // the mask will be correct.
  SetShapeTensorMask(network);
  is_pruned_input_.resize(network->getNbInputs());
  absl::c_fill(is_pruned_input_, false);
  return Status::OK();
}
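
// A minimal sketch of where AddProfiles fits in engine construction (builder
// and network creation are elided; variable names are hypothetical):
//
//   nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
//   TF_RETURN_IF_ERROR(profile.AddProfiles(builder, config, network));
//   // config now carries one TRT optimization profile per entry of
//   // profiles_, in the same order, plus an optional calibration profile.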

Status TrtShapeOptimizationProfile::ConfigureBuilder(
    nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
    const nvinfer1::INetworkDefinition* network) {
  TF_RETURN_IF_ERROR(AddProfiles(builder, config, network));
  return Status::OK();
}

// Sets the shape tensor mask from the TRT engine definition.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const nvinfer1::ICudaEngine* engine, int n_inputs) {
  is_shape_tensor_.resize(n_inputs, false);
  for (int i = 0; i < n_inputs; i++) {
    int binding_index;
    Status status = GetTrtBindingIndex(i, 0, engine, &binding_index);
    if (!status.ok()) {
      continue;
    }
    is_shape_tensor_[i] = engine->isShapeBinding(binding_index);
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape tensor at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

// Sets the shape tensor mask using the network definition.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const nvinfer1::INetworkDefinition* network) {
  int n_inputs = network->getNbInputs();
  is_shape_tensor_.resize(n_inputs, false);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    is_shape_tensor_[i] = input->isShapeTensor();
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape tensor " << input->getName() << " at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

// Sets the shape tensor mask using the input partial shapes. This only tells
// us whether the tensors are shape value compatible; only the final network
// definition or the engine can give a definitive answer.
void TrtShapeOptimizationProfile::SetShapeTensorMask(
    const std::vector<PartialTensorShape>& input_partial_shapes) {
  if (is_shape_tensor_.size() == input_partial_shapes.size()) {
    // Already initialized, e.g. by TRTEngineOp::ComputeAsync().
    return;
  }
  is_shape_tensor_.resize(input_partial_shapes.size(), false);
  for (int i = 0; i < input_partial_shapes.size(); i++) {
    is_shape_tensor_[i] = IsTrtShapeTensorCompatible(input_partial_shapes[i]);
    if (is_shape_tensor_[i]) {
      VLOG(2) << "Found shape compatible tensor at " << i;
    }
  }
  has_shape_tensor_ =
      absl::c_any_of(is_shape_tensor_, [](bool b) { return b; });
}

int TrtShapeOptimizationProfile::GetProfileNumber(
    const std::vector<TensorShape>& shapes) {
  tensorflow::profiler::TraceMe activity(
      "TrtShapeOptimizationProfile::GetProfileNumber",
      tensorflow::profiler::TraceMeLevel::kInfo);
  if (!need_profiles_) return 0;
  // TODO(tfeher): Return the best profile, not just the first compatible one.
  for (int i = 0; i < profiles_.size(); i++) {
    if (profiles_[i].IncludesShapes(shapes, HasShapeTensor(),
                                    actual_shape_values_, is_pruned_input_)) {
      return i;
    }
  }
  VLOG(1) << "Profile not found for input shapes " << DebugString(shapes)
          << ".";
  return -1;
}
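
// Sketch of runtime profile selection (hypothetical caller code):
//
//   int idx = profile.GetProfileNumber(input_shapes);
//   if (idx < 0) {
//     // No compatible profile; the engine needs to be (re)built.
//   }
//   // Otherwise, exec_contexts[idx] is the context whose profile includes
//   // input_shapes.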

Status TrtShapeOptimizationProfile::CreateExecutionContexts(
    nvinfer1::ICudaEngine* engine,
    std::vector<ExecutionContext>* exec_contexts) {
  int i = 0;
  // The following loop runs once if we have static shapes, to create a single
  // execution context without profiles. In dynamic mode we create one context
  // for each profile and set the corresponding optimization profile.
  do {
    VLOG(1) << "Creating execution context " << i;
    ExecutionContext context = ExecutionContext::Create(engine);
    if (i > 0) {
      // This condition is needed for two reasons:
      // - With static shapes we do not have any profiles, so we cannot call
      //   setOptimizationProfile.
      // - The 0th profile is set implicitly for the first execution context,
      //   therefore we do not need to set it.
      if (!context->setOptimizationProfile(i)) {
        return errors::Internal("Could not set TRT optimization profile.");
      }
    }
    exec_contexts->push_back(std::move(context));
    i++;
  } while (i < profiles_.size());

  return Status::OK();
}
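
// Sketch of pairing contexts with profiles after the engine is built or
// deserialized (hypothetical caller code):
//
//   std::vector<ExecutionContext> contexts;
//   TF_RETURN_IF_ERROR(profile.CreateExecutionContexts(engine, &contexts));
//   // contexts.size() == max(1, GetNumProfiles()), and context k is bound
//   // to optimization profile k.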

Status TrtShapeOptimizationProfile::SetInputShapeBinding(
    int input_index, int binding_index, nvinfer1::ICudaEngine* cuda_engine,
    nvinfer1::IExecutionContext* exec_context) const {
  tensorflow::profiler::TraceMe activity(
      "TrtShapeOptimizationProfile::SetInputShapeBinding",
      tensorflow::profiler::TraceMeLevel::kInfo);
  if (cuda_engine->isShapeBinding(binding_index)) {
    // The input shape binding data has to be in host memory. That is why we
    // can't use input_tensor.flat().data(), which points to the same values
    // in device memory. Instead, we use the data that was copied to the host
    // by CollectShapeValues.
    VLOG(2) << "Setting input shape binding for idx " << binding_index
            << ", with values "
            << DebugString(actual_shape_values_.at(input_index));
    bool ret = exec_context->setInputShapeBinding(
        binding_index, actual_shape_values_.at(input_index).d);
    if (!ret) {
      return errors::Internal("Could not set input shape binding for idx ",
                              binding_index);
    }
  }
  return Status::OK();
}

// If binding_idx is a shape tensor, then returns the associated min/max/opt
// shape values from prof_idx.
nvinfer1::Dims GetDimsFromShapeVal(int prof_idx, int binding_idx,
                                   nvinfer1::OptProfileSelector selector,
                                   const nvinfer1::ICudaEngine* engine) {
  if (engine->isShapeBinding(binding_idx)) {
    const int32* shape_val_ptr =
        engine->getProfileShapeValues(binding_idx, prof_idx, selector);
    if (shape_val_ptr) {
      VLOG(2) << "Found shape value in prof " << prof_idx << ", binding "
              << binding_idx;
      nvinfer1::Dims dims = engine->getBindingDimensions(binding_idx);
      // nbDims == 0 represents a scalar, -1 represents an invalid dim.
      int n_values = (dims.nbDims == 0) ? 1 : dims.d[0];
      if (n_values > 0) {
        dims.nbDims = n_values;
        std::copy(shape_val_ptr, shape_val_ptr + n_values, dims.d);
      }
      return dims;
    }
  }
  return {0, {0}};
}

Status TrtShapeOptimizationProfile::SetPrunedMask(
    const nvinfer1::ICudaEngine* engine, int n_network_inputs) {
  is_pruned_input_.resize(n_network_inputs);
  absl::c_fill(is_pruned_input_, false);
  for (int j = 0; j < n_network_inputs; j++) {
    int binding_idx;
    Status status = GetTrtBindingIndex(j, 0, engine, &binding_idx);
    if (!status.ok()) {
      // Before TRT 8, an input tensor can be pruned (nvbugs/3153064).
      // Resource inputs are also unknown to TRT, so we can treat them as
      // pruned (the engine includes the variable as weights).
      is_pruned_input_[j] = true;
      VLOG(2) << "Skipping pruned input " << j;
      continue;
    }
  }
  return Status::OK();
}

Status TrtShapeOptimizationProfile::RestoreProfiles(
    const nvinfer1::ICudaEngine* engine, int n_network_inputs) {
  need_profiles_ = false;
  if (!engine) {
    // We do not need to restore profiles for an empty engine.
    return Status::OK();
  }
  if (engine->hasImplicitBatchDimension()) {
    // Nothing to do: we cannot have profiles in implicit batch mode.
    return Status::OK();
  }
  int n_profiles = engine->getNbOptimizationProfiles();
  need_profiles_ = n_profiles > 0;
  int n_inputs = GetNumberOfEngineInputs(engine);
  if (n_inputs > n_network_inputs) {
    return errors::Internal("Incorrect number of engine inputs");
  }
  VLOG(2) << "Attempting to restore " << n_profiles << " profiles, each with "
          << n_inputs << " inputs";
  SetShapeTensorMask(engine, n_network_inputs);

  TF_RETURN_IF_ERROR(SetPrunedMask(engine, n_network_inputs));

  for (int prof_idx = 0; prof_idx < n_profiles; prof_idx++) {
    OptimizationProfileConfig cfg;

    cfg.min.resize(n_network_inputs * 2);
    cfg.max.resize(n_network_inputs * 2);
    cfg.opt.resize(n_network_inputs * 2);
    // Restore the profile dimensions and shape values for each input.
    for (int j = 0; j < n_network_inputs; j++) {
      if (is_pruned_input_[j]) continue;
      int binding_idx;
      TF_RETURN_IF_ERROR(GetTrtBindingIndex(j, 0, engine, &binding_idx));

      nvinfer1::Dims min = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          binding_idx, prof_idx, nvinfer1::OptProfileSelector::kOPT);
      cfg.min[j] = min;
      cfg.max[j] = max;
      cfg.opt[j] = opt;

      cfg.min[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMIN, engine);
      cfg.max[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kMAX, engine);
      cfg.opt[j + n_inputs] = GetDimsFromShapeVal(
          prof_idx, binding_idx, nvinfer1::OptProfileSelector::kOPT, engine);
    }
    VLOG(2) << "Restored profile " << cfg.DebugString();
    profiles_.push_back(std::move(cfg));
  }
  return Status::OK();
}

int TrtShapeOptimizationProfile::GetNumProfiles() const {
  return profiles_.size();
}

}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT