1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/testing/util.h"
35
36 namespace tflite {
37 namespace flex {
38 namespace {
39
40 using ::tensorflow::protobuf::TextFormat;
41 using ::tensorflow::protobuf::util::MessageDifferencer;
42
// Checks the following for DelegateData::Prepare():
//  - it succeeds with non-default session options;
//  - afterwards an eager context is available;
//  - GetBufferMap() returns a non-null map;
//  - it returns different maps for different TfLiteContext pointers.
TEST(DelegateDataTest, Basic) {
  DelegateData data;
  // We only check for success because it is hard to make initialization fail.
  // It only happens if we manage to not link the CPU device factory into the
  // binary.
  tensorflow::SessionOptions session_options;
  session_options.config.set_intra_op_parallelism_threads(2);
  // Fail fast and surface the status message: the checks below are not
  // meaningful if Prepare() did not succeed.
  TF_ASSERT_OK(data.Prepare(session_options));

  TfLiteContext dummy_context1 = {};
  TfLiteContext dummy_context2 = {};
  ASSERT_NE(data.GetEagerContext(), nullptr);
  EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
  EXPECT_NE(data.GetBufferMap(&dummy_context1),
            data.GetBufferMap(&dummy_context2));
}
59
// Builds a main TF Lite graph with two subgraphs ("add_subgraph" and
// "mul_subgraph"), registers FunctionDefs for the selected subgraphs, and
// checks the eager context afterwards:
//  - there is no FunctionDef for "main";
//  - each selected subgraph has a FunctionDef matching the expected text proto
//    (a TfLiteSubgraphExecute call keyed by the subgraph name).
TEST(DelegateDataTest, CheckFunctionDef) {
  tensorflow::StaticDeviceMgr device_mgr(tensorflow::DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  // The EagerContext is reference counted. Hold it in a unique_ptr whose
  // deleter drops our reference, so the reference is also released when an
  // ASSERT_* below returns early. A manual Unref() at the end would be skipped
  // in that case. The context borrows `device_mgr` (device_mgr_owned=false).
  // Declaring the holder after `device_mgr` means the context is released
  // before the device manager is destroyed.
  auto unref_context = [](tensorflow::EagerContext* context) {
    context->Unref();
  };
  std::unique_ptr<tensorflow::EagerContext, decltype(unref_context)>
      eager_context(
          new tensorflow::EagerContext(
              tensorflow::SessionOptions(),
              tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
              /*async=*/false, &device_mgr, /*device_mgr_owned*/ false,
              nullptr, nullptr),
          unref_context);

  // Selects both named subgraphs for registration; the input is ignored.
  auto select_subgraphs_to_register =
      [](const std::vector<std::unique_ptr<Subgraph>>& /*subgraphs*/,
         std::set<std::string>* result) {
        result->insert("add_subgraph");
        result->insert("mul_subgraph");
        return ::tensorflow::OkStatus();
      };

  // Builds a TF Lite primary graph with two subgraphs.
  subgraph_test_util::SubgraphBuilder builder;
  std::unique_ptr<ErrorReporter> error_reporter =
      std::make_unique<TestErrorReporter>();
  auto add_subgraph = std::make_unique<Subgraph>(
      error_reporter.get(), /*external_contexts=*/nullptr,
      /*subgraphs=*/nullptr, /*resources=*/nullptr, /*resource_ids=*/nullptr,
      /*initialization_status_map=*/nullptr);
  add_subgraph->SetName("add_subgraph");
  auto mul_subgraph = std::make_unique<Subgraph>(
      error_reporter.get(), /*external_contexts=*/nullptr,
      /*subgraphs=*/nullptr, /*resources=*/nullptr, /*resource_ids=*/nullptr,
      /*initialization_status_map=*/nullptr);
  mul_subgraph->SetName("mul_subgraph");
  builder.BuildAddSubgraph(add_subgraph.get());
  builder.BuildMulSubgraph(mul_subgraph.get());
  std::vector<std::unique_ptr<Subgraph>> subgraphs;
  subgraphs.push_back(std::move(add_subgraph));
  subgraphs.push_back(std::move(mul_subgraph));
  Subgraph main_subgraph(error_reporter.get(), nullptr, &subgraphs,
                         /*resources=*/nullptr, /*resource_ids=*/nullptr,
                         /*initialization_status_map=*/nullptr);
  main_subgraph.SetName("main");
  TF_ASSERT_OK(RegisterFunctionDefForSubgraphs(
      main_subgraph, select_subgraphs_to_register,
      eager_context->HostCPU()->resource_manager(), eager_context.get(),
      /*flex_delegate=*/nullptr));

  const std::string add_fdef_txt = R"pb(
    signature {
      name: "add_subgraph"
      input_arg { name: "args_0" type: DT_INT32 }
      input_arg { name: "args_1" type: DT_INT32 }
      output_arg { name: "res_0" type: DT_INT32 }
      is_stateful: true
    }
    node_def {
      name: "SubgraphResourceKey"
      op: "Const"
      attr {
        key: "dtype"
        value { type: DT_STRING }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_STRING
            tensor_shape {}
            string_val: "add_subgraph"
          }
        }
      }
    }
    node_def {
      name: "InvokeTfLite"
      op: "TfLiteSubgraphExecute"
      input: "SubgraphResourceKey:output:0"
      input: "args_0"
      input: "args_1"
      attr {
        key: "Tin"
        value { list { type: DT_INT32 type: DT_INT32 } }
      }
      attr {
        key: "Tout"
        value { list { type: DT_INT32 } }
      }
    }
    ret { key: "res_0" value: "InvokeTfLite:output:0" })pb";

  const std::string mul_fdef_txt = R"pb(
    signature {
      name: "mul_subgraph"
      input_arg { name: "args_0" type: DT_INT32 }
      input_arg { name: "args_1" type: DT_INT32 }
      output_arg { name: "res_0" type: DT_INT32 }
      is_stateful: true
    }
    node_def {
      name: "SubgraphResourceKey"
      op: "Const"
      attr {
        key: "dtype"
        value { type: DT_STRING }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_STRING
            tensor_shape {}
            string_val: "mul_subgraph"
          }
        }
      }
    }
    node_def {
      name: "InvokeTfLite"
      op: "TfLiteSubgraphExecute"
      input: "SubgraphResourceKey:output:0"
      input: "args_0"
      input: "args_1"
      attr {
        key: "Tin"
        value { list { type: DT_INT32 type: DT_INT32 } }
      }
      attr {
        key: "Tout"
        value { list { type: DT_INT32 } }
      }
    }
    ret { key: "res_0" value: "InvokeTfLite:output:0" })pb";

  tensorflow::FunctionDef add_fdef, mul_fdef;
  ASSERT_TRUE(TextFormat::ParseFromString(add_fdef_txt, &add_fdef));
  ASSERT_TRUE(TextFormat::ParseFromString(mul_fdef_txt, &mul_fdef));
  EXPECT_EQ(eager_context->GetFunctionDef("main"), nullptr);
  ASSERT_NE(eager_context->GetFunctionDef("add_subgraph"), nullptr);
  ASSERT_NE(eager_context->GetFunctionDef("mul_subgraph"), nullptr);
  EXPECT_TRUE(MessageDifferencer::Equals(
      *(eager_context->GetFunctionDef("add_subgraph")), add_fdef));
  EXPECT_TRUE(MessageDifferencer::Equals(
      *(eager_context->GetFunctionDef("mul_subgraph")), mul_fdef));
  // `eager_context` drops its reference when it goes out of scope.
}
204
// Registers FunctionDefs for a main TF Lite graph that has no subgraphs
// attached. Checks that the eager context ends up with no FunctionDef for
// "main", nor for any of the names the selector would pick.
TEST(DelegateDataTest, CheckFunctionDefWithOnlyMainGraph) {
  tensorflow::StaticDeviceMgr device_mgr(tensorflow::DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  // The EagerContext is reference counted. Hold it in a unique_ptr whose
  // deleter drops our reference, so the reference is also released when an
  // ASSERT_* below returns early. The context borrows `device_mgr`
  // (device_mgr_owned=false). Declaring the holder after `device_mgr` means
  // the context is released before the device manager is destroyed.
  auto unref_context = [](tensorflow::EagerContext* context) {
    context->Unref();
  };
  std::unique_ptr<tensorflow::EagerContext, decltype(unref_context)>
      eager_context(
          new tensorflow::EagerContext(
              tensorflow::SessionOptions(),
              tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
              /*async=*/false, &device_mgr, /*device_mgr_owned*/ false,
              nullptr, nullptr),
          unref_context);

  // Selects both named subgraphs for registration; the input is ignored.
  auto select_subgraphs_to_register =
      [](const std::vector<std::unique_ptr<Subgraph>>& /*subgraphs*/,
         std::set<std::string>* result) {
        result->insert("add_subgraph");
        result->insert("mul_subgraph");
        return ::tensorflow::OkStatus();
      };

  // Builds a TF Lite primary graph with no subgraphs attached.
  std::unique_ptr<ErrorReporter> error_reporter =
      std::make_unique<TestErrorReporter>();
  Subgraph main_subgraph(error_reporter.get(), /*external_contexts=*/nullptr,
                         /*subgraphs=*/nullptr, /*resources=*/nullptr,
                         /*resource_ids=*/nullptr,
                         /*initialization_status_map=*/nullptr);
  main_subgraph.SetName("main");
  TF_ASSERT_OK(RegisterFunctionDefForSubgraphs(
      main_subgraph, select_subgraphs_to_register,
      eager_context->HostCPU()->resource_manager(), eager_context.get(),
      /*flex_delegate=*/nullptr));

  EXPECT_EQ(eager_context->GetFunctionDef("main"), nullptr);
  // The selector names these, but no such subgraphs exist on the main graph,
  // so no FunctionDef is expected for them either.
  EXPECT_EQ(eager_context->GetFunctionDef("add_subgraph"), nullptr);
  EXPECT_EQ(eager_context->GetFunctionDef("mul_subgraph"), nullptr);
  // `eager_context` drops its reference when it goes out of scope.
}
240
241 } // namespace
242 } // namespace flex
243 } // namespace tflite
244