1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA && GOOGLE_TENSORRT
17 
18 #include <string.h>
19 
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
24 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "third_party/tensorrt/NvInfer.h"
30 
31 namespace tensorflow {
32 namespace tensorrt {
33 
DimVecToShapeVec(std::vector<nvinfer1::Dims3> dimvec,bool expand_with_empty_shape_values=false)34 std::vector<TensorShape> DimVecToShapeVec(
35     std::vector<nvinfer1::Dims3> dimvec,
36     bool expand_with_empty_shape_values = false) {
37   std::vector<TensorShape> shapevec(dimvec.size());
38   for (int i = 0; i < dimvec.size(); i++) {
39     TensorShape shape;
40     TF_CHECK_OK(
41         TensorShapeUtils::MakeShape(dimvec[i].d, dimvec[i].nbDims, &shape));
42     shapevec[i] = shape;
43   }
44   if (expand_with_empty_shape_values) {
45     shapevec.resize(2 * dimvec.size());  // Append empty shape values
46   }
47   return shapevec;
48 }
49 
DimsContained(const nvinfer1::Dims & dim,const nvinfer1::Dims & min,const nvinfer1::Dims & max)50 bool DimsContained(const nvinfer1::Dims& dim, const nvinfer1::Dims& min,
51                    const nvinfer1::Dims& max) {
52   if (dim.nbDims != min.nbDims || dim.nbDims != max.nbDims) {
53     return false;
54   }
55   for (int i = 0; i < dim.nbDims; i++) {
56     if (dim.d[i] < min.d[i] || dim.d[i] > max.d[i]) {
57       return false;
58     }
59   }
60   return true;
61 }
62 
DimsEqual(const nvinfer1::Dims & a,const nvinfer1::Dims & b)63 bool DimsEqual(const nvinfer1::Dims& a, const nvinfer1::Dims& b) {
64   if (a.nbDims != b.nbDims) {
65     return false;
66   }
67   for (int i = 0; i < a.nbDims; i++) {
68     if (a.d[i] != b.d[i]) {
69       return false;
70     }
71   }
72   return true;
73 }
74 
// Fixture for testing TrtShapeOptimizationProfile. Builds a minimal TensorRT
// network/engine per test; the test is parameterized over the profile
// generation strategy (see INSTANTIATE_TEST_CASE_P below).
class TrtShapeOptimizationProfileTest
    : public ::testing::TestWithParam<ProfileStrategy> {
 protected:
  TrtShapeOptimizationProfileTest() {
    strategy_ = GetParam();
    // Create the TRT builder, an explicit-batch network, and a builder config
    // shared by all test cases.
    builder_ = TrtUniquePtrType<nvinfer1::IBuilder>(
        nvinfer1::createInferBuilder(logger_));
    network_ = TrtUniquePtrType<nvinfer1::INetworkDefinition>(
        builder_->createNetworkV2(flags_));
    builder_config_ = TrtUniquePtrType<nvinfer1::IBuilderConfig>(
        builder_->createBuilderConfig());
    // 1 KiB workspace suffices for the trivial element-wise network below.
    builder_config_->setMaxWorkspaceSize(1 << 10);
  }

  // Defines a simple network: output = input1 + input2.
  // Both inputs are float tensors with shape `dims` (entries may be -1 for
  // dynamic dimensions).
  void DefineNetwork(nvinfer1::INetworkDefinition* network,
                     nvinfer1::Dims3& dims) {
    ITensorProxyPtr input1 =
        network->addInput("input1", nvinfer1::DataType::kFLOAT, dims);
    EXPECT_NE(nullptr, input1->trt_tensor());

    ITensorProxyPtr input2 =
        network->addInput("input2", nvinfer1::DataType::kFLOAT, dims);
    EXPECT_NE(nullptr, input2->trt_tensor());

    auto layer =
        network->addElementWise(*input1->trt_tensor(), *input2->trt_tensor(),
                                nvinfer1::ElementWiseOperation::kSUM);
    EXPECT_NE(nullptr, layer);
    // Mark the output.
    ITensorProxyPtr output = layer->getOutput(0);
    output->setName("output");
    network->markOutput(*output->trt_tensor());
  }

  // Asserts that `profile` maps the input shapes `dimvec` to a profile iff
  // `has_prof`. When a profile is found, checks that each shape lies within
  // the engine's [kMIN, kMAX] range for that profile; with `test_optimality`
  // set, additionally requires the shapes to equal the kOPT dimensions.
  void CheckProfile(const std::vector<nvinfer1::Dims3>& dimvec,
                    TrtShapeOptimizationProfile* profile, bool has_prof,
                    bool test_optimality) {
    std::vector<TensorShape> shape_vec = DimVecToShapeVec(dimvec);
    int idx = profile->GetProfileNumber(shape_vec);
    // GetProfileNumber returns a negative value when no profile matches.
    ASSERT_EQ(idx >= 0, has_prof);
    if (idx < 0) return;
    int prof_idx = exec_contexts_[idx]->getOptimizationProfile();
    ASSERT_GE(prof_idx, 0);
    for (int j = 0; j < dimvec.size(); j++) {
      nvinfer1::Dims min = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMIN);
      nvinfer1::Dims max = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kMAX);
      nvinfer1::Dims opt = engine->getProfileDimensions(
          j, prof_idx, nvinfer1::OptProfileSelector::kOPT);

      // This should always hold.
      EXPECT_TRUE(DimsContained(dimvec[j], min, max));

      if (test_optimality) {
        // We shall have selected an optimal strategy.
        EXPECT_TRUE(DimsEqual(dimvec[j], opt));
      }
    }
  }

  Logger& logger_ = *Logger::GetLogger();
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
  TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
  std::vector<ExecutionContext> exec_contexts_;
  // The declaration order is important: members are destroyed in reverse
  // declaration order, so exec_contexts_ is destroyed before the engine and
  // builder objects, and the logger (declared first) outlives them all.
  const uint32_t flags_ =
      1U << static_cast<int>(
          nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  ProfileStrategy strategy_;
};
150 
// Run every test below once per profile generation strategy.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the deprecated gtest spelling of
// INSTANTIATE_TEST_SUITE_P — consider migrating once the minimum googletest
// version is confirmed to support the new macro.
INSTANTIATE_TEST_CASE_P(
    OptProfilesTestInstantiation, TrtShapeOptimizationProfileTest,
    ::testing::Values(ProfileStrategy::kRange, ProfileStrategy::kOptimal,
                      ProfileStrategy::kRangeOptimal,
                      ProfileStrategy::kImplicitBatchModeCompatible));
156 
TEST_P(TrtShapeOptimizationProfileTest, Static) {
  // Static mode does not depend on strategies, we test only once.
  if (strategy_ != ProfileStrategy::kRange) return;

  // Network with static input shape.
  nvinfer1::Dims3 dims(8, 8, 10);
  DefineNetwork(network_.get(), dims);

  TrtShapeOptimizationProfile profile;

  // Configure and build engine - should be a no-op.
  TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
                                       network_.get()));

  engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
      builder_->buildEngineWithConfig(*network_, *builder_config_));
  EXPECT_NE(nullptr, engine);
  TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), &exec_contexts_));
  // A single execution context should be created for a graph with static input.
  ASSERT_EQ(exec_contexts_.size(), 1);
  EXPECT_NE(nullptr, exec_contexts_[0]);

  // Both network inputs have the same (static) shape; any lookup with those
  // shapes must resolve to the single profile, index 0.
  std::vector<nvinfer1::Dims3> dim_vec(2, dims);
  std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec);
  EXPECT_EQ(0, profile.GetProfileNumber(shape_vec));
}
183 
TEST_P(TrtShapeOptimizationProfileTest,Dynamic)184 TEST_P(TrtShapeOptimizationProfileTest, Dynamic) {
185   // Network with dynamic input shapes.
186   nvinfer1::Dims3 dims(-1, -1, 10);
187   DefineNetwork(network_.get(), dims);
188 
189   TrtShapeOptimizationProfile profile;
190 
191   // Set the input mask to true (no resource input)
192   std::vector<bool> input_mask(2, true);
193   profile.SetInputMask(input_mask);
194 
195   std::vector<std::vector<nvinfer1::Dims3>> input_profiles{
196       {nvinfer1::Dims3(2, 2, 10), nvinfer1::Dims3(2, 2, 10)},
197       {nvinfer1::Dims3(3, 3, 10), nvinfer1::Dims3(3, 3, 10)},
198       {nvinfer1::Dims3(16, 16, 10), nvinfer1::Dims3(16, 16, 10)},
199   };
200 
201   std::vector<nvinfer1::Dims3> unseen_shapes{nvinfer1::Dims3(5, 5, 10),
202                                              nvinfer1::Dims3(9, 9, 10)};
203 
204   // Simulate a profile collection phase.
205   for (auto dim_vec : input_profiles) {
206     std::vector<TensorShape> shape_vec = DimVecToShapeVec(dim_vec, true);
207     profile.AddShape(shape_vec);
208   }
209   std::vector<PartialTensorShape> input_partial_shapes;
210   TF_CHECK_OK(GetNetworkInputShapes(network_.get(), &input_partial_shapes));
211   profile.InitProfiles(input_partial_shapes, strategy_);
212 
213   // Configure and build engine.
214   TF_CHECK_OK(profile.ConfigureBuilder(builder_.get(), builder_config_.get(),
215                                        network_.get()));
216   engine = TrtUniquePtrType<nvinfer1::ICudaEngine>(
217       builder_->buildEngineWithConfig(*network_.get(), *builder_config_.get()));
218   ASSERT_NE(nullptr, engine);
219 
220   TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), &exec_contexts_));
221 
222   int n_profiles_exp;
223   switch (strategy_) {
224     case (ProfileStrategy::kImplicitBatchModeCompatible):
225     case (ProfileStrategy::kOptimal):
226       n_profiles_exp = input_profiles.size();
227       break;
228     case (ProfileStrategy::kRange):
229       n_profiles_exp = 1;
230       break;
231     case (ProfileStrategy::kRangeOptimal):
232       n_profiles_exp = 1 + input_profiles.size();
233       break;
234   }
235   // Each profile has an associated execution context.
236   EXPECT_EQ(exec_contexts_.size(), n_profiles_exp);
237 
238   profile.SetShapeTensorMask(network_.get());
239 
240   EXPECT_EQ(profile.HasShapeTensor(), false);
241 
242   // Check if the profiles are assigned correctly.
243   for (auto dimvec : input_profiles) {
244     bool test_optimal_prof = strategy_ == ProfileStrategy::kOptimal ||
245                              strategy_ == ProfileStrategy::kRangeOptimal;
246     CheckProfile(dimvec, &profile, true, test_optimal_prof);
247   }
248   bool has_prof = (strategy_ == ProfileStrategy::kRange ||
249                    strategy_ == ProfileStrategy::kRangeOptimal);
250   CheckProfile(unseen_shapes, &profile, has_prof, false);
251 }
252 
253 }  // namespace tensorrt
254 }  // namespace tensorflow
255 
256 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
257