xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/delegate_data_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/testing/util.h"
35 
36 namespace tflite {
37 namespace flex {
38 namespace {
39 
40 using ::tensorflow::protobuf::TextFormat;
41 using ::tensorflow::protobuf::util::MessageDifferencer;
42 
// Prepares a DelegateData with custom session options and checks that it
// exposes an eager context and hands out distinct, non-null buffer maps for
// distinct TfLiteContexts.
TEST(DelegateDataTest, Basic) {
  DelegateData data;
  // We only check for success because it is hard to make initialization fail.
  // It only happens if we manage to not link the CPU device factory into the
  // binary.
  tensorflow::SessionOptions session_options;
  session_options.config.set_intra_op_parallelism_threads(2);
  // TF_ASSERT_OK prints the failing status and stops the test, instead of
  // continuing to exercise a DelegateData that was never prepared.
  TF_ASSERT_OK(data.Prepare(session_options));

  TfLiteContext dummy_context1 = {};
  TfLiteContext dummy_context2 = {};
  ASSERT_NE(data.GetEagerContext(), nullptr);
  EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
  // Different TfLiteContexts are expected to map to different buffer maps.
  EXPECT_NE(data.GetBufferMap(&dummy_context1),
            data.GetBufferMap(&dummy_context2));
}
59 
TEST(DelegateDataTest,CheckFunctionDef)60 TEST(DelegateDataTest, CheckFunctionDef) {
61   tensorflow::StaticDeviceMgr device_mgr(tensorflow::DeviceFactory::NewDevice(
62       "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
63   tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
64       tensorflow::SessionOptions(),
65       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
66       /*async=*/false, &device_mgr, /*device_mgr_owned*/ false, nullptr,
67       nullptr);
68 
69   auto select_subgraphs_to_register =
70       [](const std::vector<std::unique_ptr<Subgraph>>& subgraphs,
71          std::set<std::string>* result) {
72         result->insert("add_subgraph");
73         result->insert("mul_subgraph");
74         return ::tensorflow::OkStatus();
75       };
76 
77   // Builds a TF Lite primary graph with two subgraphs.
78   subgraph_test_util::SubgraphBuilder builder;
79   std::unique_ptr<ErrorReporter> error_reporter =
80       std::make_unique<TestErrorReporter>();
81   auto add_subgraph = std::make_unique<Subgraph>(
82       error_reporter.get(), /*external_contexts=*/nullptr,
83       /*subgraphs=*/nullptr, /*resources=*/nullptr, /*resource_ids=*/nullptr,
84       /*initialization_status_map=*/nullptr);
85   add_subgraph->SetName("add_subgraph");
86   auto mul_subgraph = std::make_unique<Subgraph>(
87       error_reporter.get(), /*external_contexts=*/nullptr,
88       /*subgraphs=*/nullptr, /*resources=*/nullptr, /*resource_ids=*/nullptr,
89       /*initialization_status_map=*/nullptr);
90   mul_subgraph->SetName("mul_subgraph");
91   builder.BuildAddSubgraph(add_subgraph.get());
92   builder.BuildMulSubgraph(mul_subgraph.get());
93   std::vector<std::unique_ptr<Subgraph>> subgraphs;
94   subgraphs.push_back(std::move(add_subgraph));
95   subgraphs.push_back(std::move(mul_subgraph));
96   Subgraph main_subgraph(error_reporter.get(), nullptr, &subgraphs,
97                          /*resources=*/nullptr, /*resource_ids=*/nullptr,
98                          /*initialization_status_map=*/nullptr);
99   main_subgraph.SetName("main");
100   TF_ASSERT_OK(RegisterFunctionDefForSubgraphs(
101       main_subgraph, select_subgraphs_to_register,
102       eager_context->HostCPU()->resource_manager(), eager_context,
103       /*flex_delegate=*/nullptr));
104 
105   const string add_fdef_txt = R"pb(
106     signature {
107       name: "add_subgraph"
108       input_arg { name: "args_0" type: DT_INT32 }
109       input_arg { name: "args_1" type: DT_INT32 }
110       output_arg { name: "res_0" type: DT_INT32 }
111       is_stateful: true
112     }
113     node_def {
114       name: "SubgraphResourceKey"
115       op: "Const"
116       attr {
117         key: "dtype"
118         value { type: DT_STRING }
119       }
120       attr {
121         key: "value"
122         value {
123           tensor {
124             dtype: DT_STRING
125             tensor_shape {}
126             string_val: "add_subgraph"
127           }
128         }
129       }
130     }
131     node_def {
132       name: "InvokeTfLite"
133       op: "TfLiteSubgraphExecute"
134       input: "SubgraphResourceKey:output:0"
135       input: "args_0"
136       input: "args_1"
137       attr {
138         key: "Tin"
139         value { list { type: DT_INT32 type: DT_INT32 } }
140       }
141       attr {
142         key: "Tout"
143         value { list { type: DT_INT32 } }
144       }
145     }
146     ret { key: "res_0" value: "InvokeTfLite:output:0" })pb";
147 
148   const string mul_fdef_txt = R"pb(
149     signature {
150       name: "mul_subgraph"
151       input_arg { name: "args_0" type: DT_INT32 }
152       input_arg { name: "args_1" type: DT_INT32 }
153       output_arg { name: "res_0" type: DT_INT32 }
154       is_stateful: true
155     }
156     node_def {
157       name: "SubgraphResourceKey"
158       op: "Const"
159       attr {
160         key: "dtype"
161         value { type: DT_STRING }
162       }
163       attr {
164         key: "value"
165         value {
166           tensor {
167             dtype: DT_STRING
168             tensor_shape {}
169             string_val: "mul_subgraph"
170           }
171         }
172       }
173     }
174     node_def {
175       name: "InvokeTfLite"
176       op: "TfLiteSubgraphExecute"
177       input: "SubgraphResourceKey:output:0"
178       input: "args_0"
179       input: "args_1"
180       attr {
181         key: "Tin"
182         value { list { type: DT_INT32 type: DT_INT32 } }
183       }
184       attr {
185         key: "Tout"
186         value { list { type: DT_INT32 } }
187       }
188     }
189     ret { key: "res_0" value: "InvokeTfLite:output:0" })pb";
190 
191   tensorflow::FunctionDef add_fdef, mul_fdef;
192   ASSERT_TRUE(TextFormat::ParseFromString(add_fdef_txt, &add_fdef));
193   ASSERT_TRUE(TextFormat::ParseFromString(mul_fdef_txt, &mul_fdef));
194   EXPECT_EQ(eager_context->GetFunctionDef("main"), nullptr);
195   ASSERT_NE(eager_context->GetFunctionDef("add_subgraph"), nullptr);
196   ASSERT_NE(eager_context->GetFunctionDef("mul_subgraph"), nullptr);
197   EXPECT_TRUE(MessageDifferencer::Equals(
198       *(eager_context->GetFunctionDef("add_subgraph")), add_fdef));
199   EXPECT_TRUE(MessageDifferencer::Equals(
200       *(eager_context->GetFunctionDef("mul_subgraph")), mul_fdef));
201 
202   eager_context->Unref();
203 }
204 
// Registers FunctionDefs for a main graph that has no subgraph list attached.
// Registration must still succeed, and nothing should be added to the eager
// context, even for names the selector asks for.
TEST(DelegateDataTest, CheckFunctionDefWithOnlyMainGraph) {
  // Builds a TF Lite primary graph with no subgraphs (`subgraphs` is null).
  // It is declared before the device manager and eager context so that it
  // outlives them (locals are destroyed in reverse order).
  std::unique_ptr<ErrorReporter> error_reporter =
      std::make_unique<TestErrorReporter>();
  Subgraph main_subgraph(error_reporter.get(), /*external_contexts=*/nullptr,
                         /*subgraphs=*/nullptr, /*resources=*/nullptr,
                         /*resource_ids=*/nullptr,
                         /*initialization_status_map=*/nullptr);
  main_subgraph.SetName("main");

  // The eager context is created with device_mgr_owned=false, so `device_mgr`
  // is declared first and therefore outlives the context.
  tensorflow::StaticDeviceMgr device_mgr(tensorflow::DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  // EagerContext is released with Unref(). Holding it in a unique_ptr with an
  // Unref deleter releases it on every exit path, including the early return
  // taken by a failing TF_ASSERT_OK.
  auto unref = [](tensorflow::EagerContext* context) { context->Unref(); };
  std::unique_ptr<tensorflow::EagerContext, decltype(unref)> eager_context(
      new tensorflow::EagerContext(
          tensorflow::SessionOptions(),
          tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
          /*async=*/false, &device_mgr, /*device_mgr_owned=*/false, nullptr,
          nullptr),
      unref);

  // Asks for two subgraphs that the main graph does not actually have.
  auto select_subgraphs_to_register =
      [](const std::vector<std::unique_ptr<Subgraph>>& /*subgraphs*/,
         std::set<std::string>* result) {
        result->insert("add_subgraph");
        result->insert("mul_subgraph");
        return ::tensorflow::OkStatus();
      };

  TF_ASSERT_OK(RegisterFunctionDefForSubgraphs(
      main_subgraph, select_subgraphs_to_register,
      eager_context->HostCPU()->resource_manager(), eager_context.get(),
      /*flex_delegate=*/nullptr));

  EXPECT_EQ(eager_context->GetFunctionDef("main"), nullptr);
  // Selecting names that have no backing subgraph must not register them.
  EXPECT_EQ(eager_context->GetFunctionDef("add_subgraph"), nullptr);
  EXPECT_EQ(eager_context->GetFunctionDef("mul_subgraph"), nullptr);
}
240 
241 }  // namespace
242 }  // namespace flex
243 }  // namespace tflite
244