/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_

#include <list>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/datavec.h"
#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

43 // Stores optimization profile parameters (min/opt/max of each input shape).
44 //
45 // A TensorRT optimization profile describes the possible min/max values of
46 // each dynamic input shape along with an optimum value. These values are used
47 // by the TensorRT builder to select the best kernel for the optimum value among
48 // those kernels that are valid for all input tensors in the [min, max] range.
49 struct OptimizationProfileConfig {
50   // Length of vector == 2*num_inputs to engine. min[0:num_inputs-1] are the min
51   // input dimensions for execution tensors. If engine has shape input tensors,
52   // then min[num_inputs + i] store the shape value for input i. For inputs that
53   // are not shape tensors min = opt = max = {0, {}}.
54   //
55   // When the OptimizationProfileConfig is created from the network definition
56   // (AddProfiles), then each elements of the min, opt, max vectors are defined.
57   // When the OptimizationProfileConfig object is restored during engine
58   // deserialization (RestoreProfiles), then some inputs can be pruned
59   // (see TrtShapeOptimizationProfile::is_pruned_input_). In that case min[i]
60   // is not defined for pruned inputs (same is true for opt and max).
61   std::vector<nvinfer1::Dims> min;
62   std::vector<nvinfer1::Dims> opt;
63   std::vector<nvinfer1::Dims> max;
64 
DebugStringOptimizationProfileConfig65   string DebugString() const {
66     using absl::StrCat;
67     return StrCat("[min: ", tensorflow::tensorrt::DebugString(min),
68                   ", opt: : ", tensorflow::tensorrt::DebugString(opt),
69                   ", max: ", tensorflow::tensorrt::DebugString(max), "]");
70   }
71 
72   // Sets the min/opt/max dimensions for profile.
73   //
74   // The given min/opt/max dimensions should satisfy the condition
75   // min <= opt <= max. Additionally TRT requires that the min/opt/max values
76   // are compatible with the network input. Compatibility is defined the
77   // following way: let dim be the shape of an input binding and min/opt/max the
78   // corresponding profile dims. TRT requires that dim.d[k] must be -1 if
79   // (min.d[k] != dim.d[k] || opt.d[k] != dim.d[k] || max.d[k] != dim.d[k]).
80   //
81   // Parameters:
82   // network - TensorRT network, used to enumerate all the input tensors
83   // profile - on exit the profile information will be set for each input tensor
84   // input_mask - 1 for TRT inputs, 0 for TF inputs that are not TRT inputs
SetDimensionsOptimizationProfileConfig85   Status SetDimensions(const nvinfer1::INetworkDefinition* network,
86                        nvinfer1::IOptimizationProfile* profile,
87                        const std::vector<bool>& input_mask) const {
88     int n_inputs_trt = network->getNbInputs();
89     int n_inputs_tf = opt.size() / 2;
90     /// TODO(lsugy): check that the sum of the mask equals n_inputs.
91     if (input_mask.size() != n_inputs_tf) {
92       return errors::Internal("Incorrect input mask size: ", input_mask.size());
93     }
94     int n_mask_true = 0;
95     for (bool mask_val : input_mask) {
96       if (mask_val) {
97         n_mask_true++;
98       }
99     }
100     if (n_mask_true != n_inputs_trt) {
101       return errors::Internal(
102           "Number of true elements in input_mask (", n_mask_true,
103           ") doesn't match expected TRT inputs (", n_inputs_trt, ")");
104     }
105     int j = 0;
106     for (int i = 0; i < n_inputs_tf; i++) {
107       if (input_mask[i]) {
108         const ITensorProxyPtr input = network->getInput(j);
109         const char* name = input->getName();
110         if (input->isShapeTensor()) {
111           int idx = i + n_inputs_tf;
112           VLOG(2) << "Setting shape values for " << name << ", "
113                   << ::tensorflow::tensorrt::DebugString(opt[idx]);
114           profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMIN,
115                                   min[idx].d, min[idx].nbDims);
116           profile->setShapeValues(name, nvinfer1::OptProfileSelector::kOPT,
117                                   opt[idx].d, opt[idx].nbDims);
118           profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMAX,
119                                   max[idx].d, max[idx].nbDims);
120         }
121         VLOG(2) << "Setting input dimensions for " << name << ", "
122                 << ::tensorflow::tensorrt::DebugString(opt[i]);
123         profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN,
124                                min[i]);
125         profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT,
126                                opt[i]);
127         profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX,
128                                max[i]);
129 
130         j++;
131       }
132     }
133     return Status::OK();
134   }
135 
136   // Returns true if profile range completely includes the given shapes.
IncludesShapesOptimizationProfileConfig137   bool IncludesShapes(const std::vector<TensorShape>& shapes,
138                       bool has_shape_tensor,
139                       const std::vector<nvinfer1::Dims>& shape_values,
140                       const std::vector<bool>& is_pruned_input) const {
141     // min, max, and opt must have the same size which is already verified in
142     // SetDimensions.
143     if (min.size() != shapes.size() * 2 ||
144         (has_shape_tensor && min.size() != shape_values.size() * 2)) {
145       VLOG(2) << "Profile size mismatch min size " << min.size()
146               << " vs input shapes size " << shapes.size() << " "
147               << shape_values.size();
148       return false;
149     }
150     for (int i = 0; i < shapes.size(); i++) {
151       if (is_pruned_input[i]) {
152         continue;
153       }
154       auto current_shape = shapes[i];
155       // min, max, and opt must have the same nbDims, which is already verified
156       // in SetDimensions.
157       if (min[i].nbDims != current_shape.dims()) {
158         return false;
159       }
160       // Check if range [min, max] includes current_shape.
161       for (int dim = 0; dim < current_shape.dims(); dim++) {
162         if ((min[i].d[dim] > current_shape.dim_size(dim)) ||
163             (max[i].d[dim] < current_shape.dim_size(dim))) {
164           return false;
165         }
166       }
167     }
168     // Check shape values.
169     if (has_shape_tensor) {
170       int offset = shapes.size();
171       for (int i = 0; i < shape_values.size(); i++) {
172         if (is_pruned_input[i]) {
173           continue;
174         }
175         auto shape_val = shape_values[i];
176         // min, max, and opt must have the same nbDims, which is already
177         // verified in SetDimensions.
178         if (min[i + offset].nbDims != shape_val.nbDims) {
179           return false;
180         }
181         // Check if range [min, max] includes shape_val.
182         for (int dim = 0; dim < shape_val.nbDims; dim++) {
183           if (min[i + offset].d[dim] > shape_val.d[dim] ||
184               max[i + offset].d[dim] < shape_val.d[dim]) {
185             return false;
186           }
187         }
188       }
189     }
190     return true;
191   }
192 };
194 // Manages Optimization profiles during TRT Engine construction.
195 //
196 // An optimization profile describes a range of dimensions for each TRT network
197 // input, and the optimal dimensions that the auto-tuner should use for
198 // optimization.
199 //
200 // This class stores the list of input shapes that were seen during the
201 // build/profile_generation_mode phase, and using them it creates a set of
202 // OptimizationProfileConfigs. These configs will be added to IBuilderConfig
203 // before the engine is created.
204 class TrtShapeOptimizationProfile {
205  public:
TrtShapeOptimizationProfile()206   TrtShapeOptimizationProfile() {}
207 
208   // Stores input shape information during profile_generation_mode.
AddShape(const std::vector<TensorShape> & shapes)209   void AddShape(const std::vector<TensorShape>& shapes) {
210     input_shapes_.push_back(shapes);
211     input_shape_values_.push_back(actual_shape_values_);
212     VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles.";
213   }
214 
215   // Stores the input mask.
SetInputMask(const std::vector<bool> & input_mask)216   void SetInputMask(const std::vector<bool>& input_mask) {
217     input_mask_ = input_mask;
218   }
219 
220   // Collects ShapeTensorCompatible tensor values. This is needed both during
221   // profile_generation_mode and during normal inference calls.
222   Status CollectShapeValues(OpKernelContext* ctx);
223 
224   // Collects ShapeTensorCompatible tensor values, used only for unit tests.
225   Status CollectShapeValues(const DataVec& input);
226 
clear()227   void clear() { profiles_.clear(); }
228 
229   // Returns the profile number that should be used to execute the network with
230   // the given input shapes. Returns -1 if none of cached profiles are
231   // compatible with the given input shapes.
232   int GetProfileNumber(const std::vector<TensorShape>& shapes);
233 
234   // Creates optimization profiles and add them to the builder config.
235   Status ConfigureBuilder(nvinfer1::IBuilder* builder,
236                           nvinfer1::IBuilderConfig* config,
237                           const nvinfer1::INetworkDefinition* network);
238 
239   // Creates execution contexts for each optimization profile.
240   Status CreateExecutionContexts(nvinfer1::ICudaEngine* engine,
241                                  std::vector<ExecutionContext>* exec_contexts);
242 
243   Status SetInputShapeBinding(int input_index, int binding_index,
244                               nvinfer1::ICudaEngine* cuda_engine,
245                               nvinfer1::IExecutionContext* exec_context) const;
246 
247   // Creates optimization profiles profiles_ for the set of concrete input
248   // shapes collected in input_shapes_. The input_partial_shapes of the network
249   // is used to ensure that the created optimization profiles are compatible
250   // with the network.
251   void InitProfiles(const std::vector<PartialTensorShape>& input_partial_shapes,
252                     ProfileStrategy strategy);
253 
254   void InitCalibProfile(const std::vector<TensorShape>& shapes);
255 
256   // Returns number of created profiles.
257   int GetNumProfiles() const;
258 
HasShape()259   bool HasShape() const { return !input_shapes_.empty(); }
NeedProfiles()260   bool NeedProfiles() const { return need_profiles_; }
261 
262   // Restores profiles from the engine (used after deserialization).
263   Status RestoreProfiles(const nvinfer1::ICudaEngine* engine,
264                          int n_network_inputs);
265 
266   // Whether the network has any shape tensors.
HasShapeTensor()267   bool HasShapeTensor() const { return has_shape_tensor_; }
268 
269   void SetShapeTensorMask(const nvinfer1::INetworkDefinition* network);
270 
271   // Whether the optimization profiles describe input that can be handled with
272   // a static engine (only 1 profile with min=max).
IsStaticCompatible()273   bool IsStaticCompatible() {
274     return strategy_ == ProfileStrategy::kOptimal && profiles_.size() == 1 &&
275            !HasShapeTensor();
276     // TODO(tfeher): remove !HasShapeTensor() condition once the
277     // FixShapeValueProfile workaround is turned off.
278   }
279 
280  private:
281   // Set of input shape vetors that we collect during profile_generation_mode.
282   std::vector<std::vector<TensorShape>> input_shapes_;
283 
284   // Input shape values that we collect during profile_generation_mode. If the
285   // tensor is not compatible with a TRT shape tensor then an empty shape is
286   // stored.
287   std::vector<std::vector<nvinfer1::Dims>> input_shape_values_;
288 
289   // Shape values present in the current inference call.
290   std::vector<nvinfer1::Dims> actual_shape_values_;
291 
292   // The optimization profiles generated from input_shapes_.
293   std::vector<OptimizationProfileConfig> profiles_;
294 
295   // The optimization profile for calibration.
296   OptimizationProfileConfig calib_profiles_;
297 
298   // A TRTEngineOp can have resource inputs. These are treated as constants:
299   // their value is read during conversion and stored as weights in the TRT
300   // engine. This means that resource inputs have no corresponding TRT engine
301   // input, and we do not need to provide profile information for these. The
302   // input mask helps to identify the TRT inputs, where we need to define
303   // optimization profiles.
304   std::vector<bool> input_mask_;
305 
306   // Whether the network has any shape tensors. Initially we assume that the
307   // network might have a shape value input. This will be updated when the
308   // network is created / engine is deserialized.
309   bool has_shape_tensor_ = true;
310 
311   // Whether the network/engine requires optimization profiles.
312   bool need_profiles_ = false;
313 
314   // Whether an input tensor is a shape tensor.
315   std::vector<bool> is_shape_tensor_;
316 
317   // Whether a network input was pruned (only in TRT 7).
318   std::vector<bool> is_pruned_input_;
319 
320   // Optimization profile generation strategy.
321   ProfileStrategy strategy_;
322 
323   // Adds optimization profiles to the builder config.
324   Status AddProfiles(nvinfer1::IBuilder* builder,
325                      nvinfer1::IBuilderConfig* config,
326                      const nvinfer1::INetworkDefinition* network);
327 
328   void SetShapeTensorMask(const nvinfer1::ICudaEngine* engine, int n_inputs);
329   void SetShapeTensorMask(
330       const std::vector<PartialTensorShape>& input_partial_shapes);
331 
332   Status SetPrunedMask(const nvinfer1::ICudaEngine* engine,
333                        int n_network_inputs);
334 
335   void ImplicitBatchModeCompatibleStrategy(
336       const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
337   void OptimalStrategy(
338       const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
339   Status RangeStrategy(
340       const std::vector<std::vector<nvinfer1::Dims>>& collected_shapes);
341 };

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_