1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
17
18 #include "tensorflow/cc/ops/function_ops.h"
19 #include "tensorflow/cc/ops/resource_variable_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/compiler/jit/defs.h"
22 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
23 #include "tensorflow/compiler/jit/xla_cluster_util.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/framework/graph_to_functiondef.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/hash/hash.h"
30 #include "tensorflow/core/lib/strings/proto_serialization.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/core/util/ptr_util.h"
34
35 namespace tensorflow {
36
MakeOuterGraph(const FunctionLibraryDefinition & flib_def,const string & function)37 static std::unique_ptr<Graph> MakeOuterGraph(
38 const FunctionLibraryDefinition& flib_def, const string& function) {
39 Scope scope = Scope::NewRootScope().ExitOnError();
40 TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
41
42 auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
43 auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
44 auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
45 auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
46 auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
47 auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
48 auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
49
50 NodeDef def;
51 TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
52 .Input(a.node()->name(), 0, DT_INT32)
53 .Input(b.node()->name(), 0, DT_FLOAT)
54 .Input(c.node()->name(), 0, DT_INT32)
55 .Input(d.node()->name(), 0, DT_FLOAT)
56 .Input(u.node()->name(), 0, DT_RESOURCE)
57 .Input(v.node()->name(), 0, DT_RESOURCE)
58 .Input(w.node()->name(), 0, DT_RESOURCE)
59 .Device("/gpu:0")
60 .Attr(kXlaClusterIdAttr, "launch0")
61 .Attr("_variable_start_index", 4)
62 .Finalize(&def));
63
64 Status status;
65 Node* launch = scope.graph()->AddNode(def, &status);
66 TF_CHECK_OK(status);
67 TF_CHECK_OK(scope.DoShapeInference(launch));
68 scope.graph()->AddEdge(a.node(), 0, launch, 0);
69 scope.graph()->AddEdge(b.node(), 0, launch, 1);
70 scope.graph()->AddEdge(c.node(), 0, launch, 2);
71 scope.graph()->AddEdge(d.node(), 0, launch, 3);
72 scope.graph()->AddEdge(u.node(), 0, launch, 4);
73 scope.graph()->AddEdge(v.node(), 0, launch, 5);
74 scope.graph()->AddEdge(w.node(), 0, launch, 6);
75
76 auto out0 =
77 ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
78 auto out1 =
79 ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
80 auto out2 =
81 ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
82 auto out3 =
83 ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
84
85 auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
86 auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
87 auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
88 auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
89 auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
90 auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
91
92 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
93 TF_CHECK_OK(scope.ToGraph(graph.get()));
94 return graph;
95 }
96
97 // Makes an encapsulate body graph for use in tests.
MakeBodyGraph()98 static std::unique_ptr<Graph> MakeBodyGraph() {
99 Scope scope = Scope::NewRootScope().ExitOnError();
100
101 auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
102 auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
103 auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
104 auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
105
106 auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
107 auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
108 auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
109
110 auto add_attrs = [](Node* node) {
111 node->AddAttr(kXlaClusterIdAttr, "launch0");
112 node->set_requested_device("/gpu:0");
113 };
114
115 auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
116 add_attrs(b_identity.node());
117 auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
118 add_attrs(read_u.node());
119 auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
120 add_attrs(read_v.node());
121 auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
122 add_attrs(read_w.node());
123
124 auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
125 add_attrs(e.node());
126 auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
127 add_attrs(f.node());
128 auto g = ops::Add(scope.WithOpName("G"), f, arg3);
129 add_attrs(g.node());
130
131 auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
132 b_identity, 0);
133 auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
134 auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
135 auto out3 =
136 ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
137
138 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
139 TF_CHECK_OK(scope.ToGraph(graph.get()));
140 return graph;
141 }
142
// Tests that the order in which control edges are inserted does not affect
// the cache key (cluster name) produced by EncapsulateXlaComputationsPass,
// while the order of data operands does.
TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
  // Builds the graph A0, A1 -> E (an Add tagged with cluster "launch0"), adds
  // control edges A0 -> E and A1 -> E, encapsulates it and returns the
  // deterministic serialization of the result.
  //   control_input_reversed: insert the A1 -> E control edge before A0 -> E.
  //   operand_reversed:       build E = Add(A1, A0) instead of Add(A0, A1).
  auto get_serialized_graph = [](bool control_input_reversed,
                                 bool operand_reversed) -> string {
    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
    std::unique_ptr<Graph> graph(new Graph(&flib_def));
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
      auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);

      ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a1, a0)
                                    : ops::Add(scope.WithOpName("E"), a0, a1);
      e.node()->AddAttr(kXlaClusterIdAttr, "launch0");

      TF_CHECK_OK(scope.ToGraph(graph.get()));

      // Map nodes created in `scope` to their counterparts in `graph` by id.
      Node* a0_node = graph->FindNodeId(a0.node()->id());
      Node* a1_node = graph->FindNodeId(a1.node()->id());
      Node* e_node = graph->FindNodeId(e.node()->id());

      // Insert the control edges in the requested order. The order must not
      // affect the encapsulated or serialized graph.
      Node* first = control_input_reversed ? a1_node : a0_node;
      Node* second = control_input_reversed ? a0_node : a1_node;
      graph->AddControlEdge(first, e_node, /*allow_duplicates=*/true);
      graph->AddControlEdge(second, e_node, /*allow_duplicates=*/true);
    }
    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
    return SerializeGraphDeterministic(*graph).ValueOrDie();
  };

  // Changing the order of control inputs shouldn't affect the graph generated.
  EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
                                 /*operand_reversed=*/false),
            get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/false));

  // Changing the order of data inputs should affect the graph generated.
  EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/true),
            get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/false));
}
197
// Encapsulates a graph containing one XLA cluster ("launch0") and checks that:
//  (1) the rewritten outer graph matches MakeOuterGraph(),
//  (2) the generated function matches MakeBodyGraph() with the expected
//      argument and return types, and
//  (3) encapsulating an identical copy of the input reuses the same function
//      name, i.e. encapsulation is deterministic (avoids recompilation).
TEST(EncapsulateXlaComputations, Encapsulate) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
    auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
    auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
    auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
    auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
    auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
    auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);

    // Marks a node as part of cluster "launch0" on "/gpu:0".
    auto add_attrs = [](Node* node) {
      node->AddAttr(kXlaClusterIdAttr, "launch0");
      node->set_requested_device("/gpu:0");
    };

    auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
    add_attrs(b_identity.node());

    auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
    add_attrs(read_u.node());
    auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
    add_attrs(read_v.node());
    auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
    add_attrs(read_w.node());

    auto e = ops::Add(scope.WithOpName("E"), a, c);
    add_attrs(e.node());
    auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
    add_attrs(f.node());
    auto g = ops::Add(scope.WithOpName("G"), f, d);
    add_attrs(g.node());

    // Cluster outputs and their consumers outside the cluster.
    auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
    auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
    auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
    auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);

    auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
    auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
    auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
    auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
    auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
    auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
    TF_ASSERT_OK(scope.ToGraph(graph.get()));
  }

  // Snapshot the input graph before Encapsulate modifies `graph`, so the same
  // computation can be encapsulated a second time below.
  std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
  CopyGraph(*graph, graph_copy.get());

  TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));

  std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex();
  string function = index.at("launch0")->type_string();

  // Tests the outer graph is as expected.
  {
    std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
    GraphDef expected_def;
    outer->ToGraphDef(&expected_def);

    GraphDef actual_def;
    graph->ToGraphDef(&actual_def);
    TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
  }

  // Tests the encapsulated body graph is as expected.
  {
    std::unique_ptr<Graph> body = MakeBodyGraph();
    GraphDef expected_body_def;
    body->ToGraphDef(&expected_body_def);

    InstantiationResultForTest result;
    // Assert rather than expect: every check below reads `result`, which is
    // meaningless if instantiation failed.
    TF_ASSERT_OK(InstantiateFunctionForTest(function, flib_def, &result));

    EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
                              DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
              result.arg_types);
    EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
              result.ret_types);
    TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
  }

  // Encapsulates the same computation again, verifies we reuse the same
  // function. Encapsulation should be deterministic to avoid recompilation.
  TF_ASSERT_OK(
      EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
  std::unordered_map<string, Node*> index_copy =
      graph_copy->BuildNodeNameIndex();
  string function_copy = index_copy.at("launch0")->type_string();
  EXPECT_EQ(function, function_copy);
}
292
// Tests that BuildXlaLaunchOps rewrites the "launch0" function-call node of
// MakeOuterGraph() into an XlaLaunch op. The op has no constant inputs, takes
// A-D as regular inputs and U-W as resource inputs, and feeds its four
// results to the original consumers.
TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
  FunctionDefLibrary flib;
  TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
  TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));

  // Build the expected graph by hand, with shape inference disabled.
  Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
  // Assert, consistent with the other setup steps: the comparison below is
  // meaningless if the expected graph could not be set up.
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);

  NameAttrList function;
  function.set_name("launch0");
  auto launch = ops::XlaLaunch(
      scope.WithOpName("launch0").WithDevice("/gpu:0"),
      std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
      std::initializer_list<Input>{u, v, w},
      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);

  auto consumer0_a =
      ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
  auto consumer0_b =
      ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
  auto consumer0_c =
      ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
  auto consumer1 =
      ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
  auto consumer2 =
      ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
  auto consumer3 =
      ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);

  GraphDef expected_def;
  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));

  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
342
343 } // namespace tensorflow
344