xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <array>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/version.h"
53 
54 #if GOOGLE_CUDA && GOOGLE_TENSORRT
55 
56 namespace tensorflow {
57 namespace tensorrt {
58 using ::absl::StrCat;
59 using ::testing::ElementsAre;
60 
// Parameter pack for TRTEngineOpTestWithParam: selects whether the TensorRT
// engine is prebuilt at op-construction time (static) or converted
// dynamically during inference.
struct TestParam {
  bool static_engine;
};
64 
// Common fixture for TRTEngineOp kernel tests. Builds a one-node
// `add = input + input` segment, wires it up as a TRTEngineOp on a GPU
// device, and provides helpers for feeding inputs and resetting them
// between runs.
class TRTEngineOpTestBase : public OpsTestBase {
 public:
  // Builds and initializes a TRTEngineOp named `kOpName`.
  //
  // dtype: element type of the single input/output.
  // max_cached_engines_count: value for the op's engine-cache-size attr.
  // shape: the (possibly partial) input/output shape attr.
  // use_implicit_batch: selects implicit- vs explicit-batch TensorRT mode.
  // allow_build_at_runtime: whether the kernel may build engines during
  //   inference (when false, unseen shapes fall back to the native segment).
  // static_engine: when true, a TensorRT engine is serialized here and
  //   embedded in the `serialized_segment` attr; otherwise the segment is
  //   converted dynamically at inference time.
  void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1,
                      PartialTensorShape shape = PartialTensorShape({-1, -1}),
                      bool use_implicit_batch = true,
                      bool allow_build_at_runtime = true,
                      bool static_engine = false) {
    // Create the GPU device.
    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));

    // Create simple TF graph.
    Scope s = Scope::NewRootScope();
    auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0);
    auto add = ops::Add(s.WithOpName("add"), feed, feed);
    ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0);

    // Serialize the graph. TRTEngineOp will convert it using dynamic mode.
    GraphDef graph_def;
    TF_ASSERT_OK(s.ToGraphDef(&graph_def));
    Graph* graph = s.graph();
    // Register the segment as a function so the kernel's native-segment
    // fallback path (`<kOpName>_native_segment`) can be instantiated.
    TF_ASSERT_OK(convert::RegisterGraphToFunctionLibrary(graph_def, graph,
                                                         std::string(kOpName)));
    TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));

    string segment_string;
    if (static_engine) {
      convert::TRTOptimizationPass::ConversionParams params;
      convert::EngineInfo info;
      info.segment_graph_def.CopyFrom(graph_def);
      info.precision_mode = TrtPrecisionMode::FP32;
      info.max_workspace_size_bytes = 1 << 20;
      info.engine_name = "TRTEngineOP_000_000";
      params.use_implicit_batch = use_implicit_batch;
      params.trt_logger_name = "DefaultLogger";

      TrtShapeOptimizationProfile profile;
      // We set the input mask to true (no resource inputs)
      std::vector<bool> input_mask = {true};
      profile.SetInputMask(input_mask);
      // We set profile 0 to be incompatible with the input used in the test.
      // This way we ensure that profile selection is tested.
      TensorShape my_shape;
      TF_CHECK_OK(
          TensorShapeUtils::MakeShape(std::vector<int32>{4, 2}, &my_shape));
      profile.AddShape({my_shape, {}});
      TF_CHECK_OK(
          TensorShapeUtils::MakeShape(std::vector<int32>{1, 2}, &my_shape));
      profile.AddShape({my_shape, {}});

      // NOTE(review): return value ignored here — presumably a Status;
      // confirm whether this should be wrapped in TF_CHECK_OK.
      profile.InitProfiles({shape}, ProfileStrategy::kOptimal);
      std::vector<PartialTensorShape> shape_vec{shape, {}};
      TF_CHECK_OK(convert::CreateStaticEngine(
          params, info, 1, shape_vec, &profile, &segment_string, nullptr));
    }

    // Create the op.
    // In implicit batch mode, the input shapes that we specify here are not
    // used for engine creation, we use the concrete shapes during inference
    // time for creating the engine.
    // In explicit batch mode, the input shapes attribute is used to define
    // the network for the TensorRT engine.
    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
    NameAttrList function;
    function.set_name(StrCat(std::string(kOpName), "_native_segment"));
    // We disable allow_soft_placement when executing the native segment of the
    // TRTEngineOp for the following reasons:
    //    OpsTestBase only allow one device in the device manager.
    //    We need to define the GPU device to test TRTEngineOp.
    //    When allow_soft_placement is true, the TensorFlow runtime produces an
    //      error if a CPU device is not defined
    //      (see ProcessFunctionLibraryRuntime::InstantiateMultiDevice).
    TF_ASSERT_OK(NodeDefBuilder(std::string(kOpName), "TRTEngineOp")
                     .Input(FakeInput(1, dtype))
                     .Attr("input_shapes", {shape})
                     .Attr("output_shapes", {shape})
                     .Attr("static_engine", static_engine)
                     .Attr("segment_func", function)
                     .Attr("serialized_segment", segment_string)
                     .Attr("calibration_data", "")
                     .Attr("max_cached_engines_count", max_cached_engines_count)
                     .Attr("workspace_size_bytes", 1 << 20)
                     .Attr("precision_mode", "FP32")
                     .Attr("use_calibration", false)
                     .Attr("profile_strategy", "optimal")
                     .Attr("_use_implicit_batch", use_implicit_batch)
                     .Attr("_allow_build_at_runtime", allow_build_at_runtime)
                     .Attr("_allow_soft_placement", false)
                     .Attr("OutT", {dtype})
                     .Finalize(OpsTestBase::node_def()));
    TF_ASSERT_OK(InitOpWithFunctionLibrary());
  }

  // Name under which the op (and its engine-cache resource) is registered.
  static const absl::string_view kOpName;

  // Adds one input tensor of the given shape filled with 0, 1, 2, ...
  template <typename T>
  void AddSimpleInput(const TensorShape& shape) {
    std::vector<T> input(shape.num_elements());
    std::iota(input.begin(), input.end(), T(0));
    OpsTestBase::AddInputFromArray<T>(shape, input);
  }

  // Clears previously added inputs so the kernel can be re-run with new
  // ones. OpsTestBase stores raw owning Tensor pointers in tensors_, so they
  // must be deleted explicitly here.
  void ResetInputs() {
    inputs_.clear();
    for (auto& temp : tensors_) {
      delete temp;
    }
    tensors_.clear();
  }

 private:
  // Like OpsTestBase::InitOp, but creates the kernel with the process
  // function library runtime so the `segment_func` attr can be resolved
  // against the registered native segment.
  Status InitOpWithFunctionLibrary() {
    OpKernel* kernel = nullptr;
    auto flr = pflr_->GetFLR(device_->name());
    std::shared_ptr<const NodeProperties> props;
    Status status = NodeProperties::CreateFromNodeDef(
        node_def_, flr->GetFunctionLibraryDefinition(), &props);
    if (status.ok()) {
      status.Update(CreateOpKernel(device_type_, device_, allocator(), flr,
                                   props, TF_GRAPH_DEF_VERSION, &kernel));
    }
    kernel_ = std::unique_ptr<OpKernel>(kernel);
    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
    return status;
  }
};
191 
// Fixture for tests that run both with and without a statically prebuilt
// engine, parameterized by TestParam.
class TRTEngineOpTestWithParam
    : public TRTEngineOpTestBase,
      public ::testing::WithParamInterface<TestParam> {
 public:
  TRTEngineOpTestWithParam() : param_(GetParam()) {}

 protected:
  // Parameter for the current instantiation, cached at construction.
  TestParam param_;
};
201 
202 const absl::string_view TRTEngineOpTestBase::kOpName = "myop";
203 
204 constexpr std::array<TestParam, 2> TestParameters{TestParam{false},
205                                                   TestParam{true}};
206 
207 INSTANTIATE_TEST_CASE_P(TRTEngineOpTestInstantiation, TRTEngineOpTestWithParam,
208                         ::testing::ValuesIn(TestParameters));
209 
// Implicit-batch mode: engines are built lazily per input shape, smaller
// batches reuse an engine built for a larger batch with the same non-batch
// dimensions, and the cache is capped by max_cached_engines_count.
TEST_F(TRTEngineOpTestBase, DynamicEngines) {
  // Test dynamic engine creation during inference time
  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);

  // Execute the op with batch size > 1.
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({2, 2}));
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // Get the engine cache.
  TRTEngineCacheResource* cache_resource = nullptr;
  TF_ASSERT_OK(device_->resource_manager()->Lookup(
      std::string(kTfTrtContainerName), std::string(kOpName), &cache_resource));
  core::ScopedUnref sc(cache_resource);

  // It should contain only one engine.
  auto cache = &cache_resource->cache_;
  EXPECT_EQ(1, cache->size());
  EXPECT_EQ(1, cache->count({TensorShape({2, 2})}));

  // Execute the op with batch size 1. It should reuse existing engine to
  // execute.
  ResetInputs();
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 2}));
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
  EXPECT_EQ(1, cache->size());
  EXPECT_EQ(1, cache->count({TensorShape({2, 2})}));

  // Execute the op with a larger batch size. This cannot reuse the engine
  // built for batch 2, so a second engine is created.
  ResetInputs();
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({3, 2}));
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
  EXPECT_EQ(2, cache->size());
  EXPECT_EQ(1, cache->count({TensorShape({2, 2})}));
  EXPECT_EQ(1, cache->count({TensorShape({3, 2})}));

  // Execute the op with an input that has different non-batch dimension.
  ResetInputs();
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({10, 10}));
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
  // Execute it again with an input that has the same non-batch dimension but
  // smallest batch size. It should find the correct engine to use.
  ResetInputs();
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 10}));
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
  EXPECT_EQ(3, cache->size());  // Should only create 3 engines in total.
  EXPECT_EQ(1, cache->count({TensorShape({2, 2})}));
  EXPECT_EQ(1, cache->count({TensorShape({3, 2})}));
  EXPECT_EQ(1, cache->count({TensorShape({10, 10})}));
}
259 
// With runtime engine building disabled, the kernel must fall back to the
// native segment and record the attempt as an empty cache entry.
TEST_F(TRTEngineOpTestBase, AllowBuildAtRuntime) {
  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
                                      PartialTensorShape({-1, -1}),
                                      /*use_implicit_batch=*/true,
                                      /*allow_build_at_runtime=*/false);

  // Run the kernel once with a concrete shape.
  const TensorShape run_shape({2, 2});
  TRTEngineOpTestBase::AddSimpleInput<float>(run_shape);
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // Fetch the engine-cache resource the kernel registered for this op.
  TRTEngineCacheResource* resource = nullptr;
  TF_ASSERT_OK(device_->resource_manager()->Lookup(
      std::string(kTfTrtContainerName), std::string(kOpName), &resource));
  core::ScopedUnref unref_resource(resource);

  // Exactly one placeholder entry with a null cuda engine: this marks that
  // engine creation was not performed for the shape we ran with.
  auto* engine_cache = &resource->cache_;
  EXPECT_EQ(1, engine_cache->size());
  ASSERT_EQ(1, engine_cache->count({run_shape}));
  EngineContext* engine_context = engine_cache->at({run_shape}).get();
  EXPECT_EQ(engine_context->GetCudaEngine(), nullptr);
}
285 
// Explicit-batch inference where all input dimensions are known at engine
// creation time, so TensorRT can build the engine up front.
TEST_P(TRTEngineOpTestWithParam, ExplicitBatch) {
  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
                                      /*shape=*/PartialTensorShape({1, 2}),
                                      /*use_implicit_batch=*/false,
                                      /*allow_build_at_runtime=*/true,
                                      /*static_engine=*/param_.static_engine);

  const TensorShape run_shape({1, 2});
  TRTEngineOpTestBase::AddSimpleInput<float>(run_shape);
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // Fetch the engine-cache resource the kernel registered for this op.
  TRTEngineCacheResource* resource = nullptr;
  TF_ASSERT_OK(device_->resource_manager()->Lookup(
      std::string(kTfTrtContainerName), std::string(kOpName), &resource));
  core::ScopedUnref unref_resource(resource);

  // The cache must hold exactly one usable engine for the shape we ran with.
  auto* engine_cache = &resource->cache_;
  EXPECT_EQ(1, engine_cache->size());
  ASSERT_EQ(1, engine_cache->count({run_shape}));
  EngineContext* engine_context = engine_cache->at({run_shape}).get();
  EXPECT_NE(engine_context->GetCudaEngine(), nullptr);
}
312 
TEST_P(TRTEngineOpTestWithParam, DynamicShapes) {
  // Test inference in explicit batch mode with dynamic input shapes. Dynamic
  // shapes in this context means that some input shapes for TensorRT are
  // unknown during engine creation time. When we create the network, the
  // unknown shapes are represented as -1. Before we run inference, these
  // shapes have to be specified by calling setBindingDimensions.
  TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
                                      /*shape=*/PartialTensorShape({-1, -1}),
                                      /*use_implicit_batch=*/false,
                                      /*allow_build_at_runtime=*/true,
                                      param_.static_engine);

  TensorShape input_shape({1, 2});
  TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);

  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // Get the engine cache.
  TRTEngineCacheResource* cache_resource = nullptr;
  TF_ASSERT_OK(device_->resource_manager()->Lookup(
      std::string(kTfTrtContainerName), std::string(kOpName), &cache_resource));
  core::ScopedUnref sc(cache_resource);

  // A single engine was built and is usable for the {1, 2} input.
  auto cache = &cache_resource->cache_;
  EXPECT_EQ(1, cache->size());
  ASSERT_EQ(1, cache->count({input_shape}));
  EngineContext* ectx = cache->at({input_shape}).get();
  EXPECT_NE(ectx->GetCudaEngine(), nullptr);

  // Execute the op with an incompatible shape.
  ResetInputs();
  TRTEngineOpTestBase::AddSimpleInput<float>(TensorShape({1, 37}));
  // Test that the op runs. This should fall back to native segment.
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());
  // We should still have a single engine that is not compatible with the input.
  EXPECT_EQ(1, cache->size());
  EXPECT_EQ(0, cache->count({TensorShape({1, 37})}));
}
351 
// Typed fixture: runs the Basic test below once per element type in TypeList.
template <typename T>
class TRTEngineOpTest : public TRTEngineOpTestBase {};

using TypeList = ::testing::Types<float, Eigen::half>;
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);
357 
// End-to-end check per element type: the segment computes x + x.
TYPED_TEST(TRTEngineOpTest, Basic) {
  TRTEngineOpTestBase::AddSimpleTrtOp(DataTypeToEnum<TypeParam>::v());

  // Feed [0, 1] and run the kernel.
  OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),
                                            {TypeParam(0.0f), TypeParam(1.0f)});
  TF_ASSERT_OK(OpsTestBase::RunOpKernel());

  // The doubled values must come back in output 0.
  Tensor* result = OpsTestBase::GetOutput(0);
  auto flat = result->template flat<TypeParam>();
  EXPECT_THAT(absl::Span<const TypeParam>(flat.data(), result->NumElements()),
              ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
373 
374 }  // namespace tensorrt
375 }  // namespace tensorflow
376 
377 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
378