1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
17 
18 #include "tensorflow/cc/ops/function_ops.h"
19 #include "tensorflow/cc/ops/resource_variable_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/compiler/jit/defs.h"
22 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
23 #include "tensorflow/compiler/jit/xla_cluster_util.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/framework/graph_to_functiondef.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/hash/hash.h"
30 #include "tensorflow/core/lib/strings/proto_serialization.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/util/equal_graph_def.h"
33 #include "tensorflow/core/util/ptr_util.h"
34 
35 namespace tensorflow {
36 
MakeOuterGraph(const FunctionLibraryDefinition & flib_def,const string & function)37 static std::unique_ptr<Graph> MakeOuterGraph(
38     const FunctionLibraryDefinition& flib_def, const string& function) {
39   Scope scope = Scope::NewRootScope().ExitOnError();
40   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
41 
42   auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
43   auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
44   auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
45   auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
46   auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
47   auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
48   auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
49 
50   NodeDef def;
51   TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
52                   .Input(a.node()->name(), 0, DT_INT32)
53                   .Input(b.node()->name(), 0, DT_FLOAT)
54                   .Input(c.node()->name(), 0, DT_INT32)
55                   .Input(d.node()->name(), 0, DT_FLOAT)
56                   .Input(u.node()->name(), 0, DT_RESOURCE)
57                   .Input(v.node()->name(), 0, DT_RESOURCE)
58                   .Input(w.node()->name(), 0, DT_RESOURCE)
59                   .Device("/gpu:0")
60                   .Attr(kXlaClusterIdAttr, "launch0")
61                   .Attr("_variable_start_index", 4)
62                   .Finalize(&def));
63 
64   Status status;
65   Node* launch = scope.graph()->AddNode(def, &status);
66   TF_CHECK_OK(status);
67   TF_CHECK_OK(scope.DoShapeInference(launch));
68   scope.graph()->AddEdge(a.node(), 0, launch, 0);
69   scope.graph()->AddEdge(b.node(), 0, launch, 1);
70   scope.graph()->AddEdge(c.node(), 0, launch, 2);
71   scope.graph()->AddEdge(d.node(), 0, launch, 3);
72   scope.graph()->AddEdge(u.node(), 0, launch, 4);
73   scope.graph()->AddEdge(v.node(), 0, launch, 5);
74   scope.graph()->AddEdge(w.node(), 0, launch, 6);
75 
76   auto out0 =
77       ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
78   auto out1 =
79       ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
80   auto out2 =
81       ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
82   auto out3 =
83       ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
84 
85   auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
86   auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
87   auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
88   auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
89   auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
90   auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
91 
92   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
93   TF_CHECK_OK(scope.ToGraph(graph.get()));
94   return graph;
95 }
96 
97 // Makes an encapsulate body graph for use in tests.
MakeBodyGraph()98 static std::unique_ptr<Graph> MakeBodyGraph() {
99   Scope scope = Scope::NewRootScope().ExitOnError();
100 
101   auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
102   auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
103   auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
104   auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
105 
106   auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
107   auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
108   auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
109 
110   auto add_attrs = [](Node* node) {
111     node->AddAttr(kXlaClusterIdAttr, "launch0");
112     node->set_requested_device("/gpu:0");
113   };
114 
115   auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
116   add_attrs(b_identity.node());
117   auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
118   add_attrs(read_u.node());
119   auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
120   add_attrs(read_v.node());
121   auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
122   add_attrs(read_w.node());
123 
124   auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
125   add_attrs(e.node());
126   auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
127   add_attrs(f.node());
128   auto g = ops::Add(scope.WithOpName("G"), f, arg3);
129   add_attrs(g.node());
130 
131   auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
132                            b_identity, 0);
133   auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
134   auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
135   auto out3 =
136       ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
137 
138   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
139   TF_CHECK_OK(scope.ToGraph(graph.get()));
140   return graph;
141 }
142 
TEST(EncapsulateXlaComputations,DeterministicEncapsulate)143 TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
144   // Test that control edge insertion order doesn't affect the cache key
145   // (cluster name) generated by TPU encapsulate pass.
146   auto get_serialized_graph = [](bool control_input_reversed,
147                                  bool operand_reversed) -> string {
148     FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
149     std::unique_ptr<Graph> graph(new Graph(&flib_def));
150     {
151       Scope scope = Scope::NewRootScope().ExitOnError();
152       auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
153       auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
154 
155       ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
156                                     : ops::Add(scope.WithOpName("E"), a1, a0);
157 
158       auto add_attrs = [](Node* node) {
159         node->AddAttr(kXlaClusterIdAttr, "launch0");
160       };
161       add_attrs(e.node());
162 
163       TF_CHECK_OK(scope.ToGraph(graph.get()));
164       auto get_node_in_graph = [&graph](Node* node) {
165         return graph->FindNodeId(node->id());
166       };
167       // Insert control edge in different order. The order should not affect
168       // the encapsulated or serialized graph.
169       if (!control_input_reversed) {
170         graph->AddControlEdge(get_node_in_graph(a0.node()),
171                               get_node_in_graph(e.node()), true);
172         graph->AddControlEdge(get_node_in_graph(a1.node()),
173                               get_node_in_graph(e.node()), true);
174       } else {
175         graph->AddControlEdge(get_node_in_graph(a1.node()),
176                               get_node_in_graph(e.node()), true);
177         graph->AddControlEdge(get_node_in_graph(a0.node()),
178                               get_node_in_graph(e.node()), true);
179       }
180     }
181     TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
182     return SerializeGraphDeterministic(*graph).ValueOrDie();
183   };
184 
185   // Changing the order of control input shouldn't affect the graph generated.
186   EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
187                                  /*operand_reversed=*/false),
188             get_serialized_graph(/*control_input_reversed=*/false,
189                                  /*operand_reversed=*/false));
190 
191   // Changing the order of data input should affect the graph generated.
192   EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
193                                  /*operand_reversed=*/true),
194             get_serialized_graph(/*control_input_reversed=*/false,
195                                  /*operand_reversed=*/false));
196 }
197 
TEST(EncapsulateXlaComputations,Encapsulate)198 TEST(EncapsulateXlaComputations, Encapsulate) {
199   FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
200   std::unique_ptr<Graph> graph(new Graph(&flib_def));
201   {
202     Scope scope = Scope::NewRootScope().ExitOnError();
203     auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
204     auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
205     auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
206     auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
207     auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
208     auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
209     auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
210 
211     auto add_attrs = [](Node* node) {
212       node->AddAttr(kXlaClusterIdAttr, "launch0");
213       node->set_requested_device("/gpu:0");
214     };
215 
216     auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
217     add_attrs(b_identity.node());
218 
219     auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
220     add_attrs(read_u.node());
221     auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
222     add_attrs(read_v.node());
223     auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
224     add_attrs(read_w.node());
225 
226     auto e = ops::Add(scope.WithOpName("E"), a, c);
227     add_attrs(e.node());
228     auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
229     add_attrs(f.node());
230     auto g = ops::Add(scope.WithOpName("G"), f, d);
231     add_attrs(g.node());
232 
233     auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
234     auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
235     auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
236     auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
237 
238     auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
239     auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
240     auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
241     auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
242     auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
243     auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
244     TF_ASSERT_OK(scope.ToGraph(graph.get()));
245   }
246 
247   std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
248   CopyGraph(*graph, graph_copy.get());
249 
250   TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
251 
252   std::unordered_map<string, Node*> index = graph->BuildNodeNameIndex();
253   string function = index.at("launch0")->type_string();
254 
255   // Tests the outer graph is as expected.
256   {
257     std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
258     GraphDef expected_def;
259     outer->ToGraphDef(&expected_def);
260 
261     GraphDef actual_def;
262     graph->ToGraphDef(&actual_def);
263     TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
264   }
265 
266   // Tests the encapsulated body graph is as expected.
267   {
268     std::unique_ptr<Graph> body = MakeBodyGraph();
269     GraphDef expected_body_def;
270     body->ToGraphDef(&expected_body_def);
271 
272     InstantiationResultForTest result;
273     TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
274 
275     EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
276                               DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
277               result.arg_types);
278     EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
279               result.ret_types);
280     TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
281   }
282 
283   // Encapsulates the same computation again, verifies we reuse the same
284   // function. Encapsulation should be deterministic to avoid recompilation.
285   TF_ASSERT_OK(
286       EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
287   std::unordered_map<string, Node*> index_copy =
288       graph_copy->BuildNodeNameIndex();
289   string function_copy = index_copy.at("launch0")->type_string();
290   EXPECT_EQ(function, function_copy);
291 }
292 
// Checks that BuildXlaLaunchOps turns the encapsulated "launch0" call in the
// outer graph into an XlaLaunch node.
//
// The expected node has:
//   - No constants.
//   - Data arguments A-D.
//   - Resource arguments U-W.
//   - The body's four result types.
// The original Identity consumers should read from its results.
TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
  std::unique_ptr<Graph> body = MakeBodyGraph();
  FunctionDefLibrary library;
  TF_ASSERT_OK(GraphToFunctionDef(*body, "launch0", library.add_function()));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), library);

  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
  TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));

  // Hand-build the graph the rewrite should produce.
  Scope expected = Scope::DisabledShapeInferenceScope().ExitOnError();
  TF_EXPECT_OK(expected.graph()->AddFunctionLibrary(library));

  auto placeholder = [&expected](const char* name, DataType dtype) {
    return ops::Placeholder(expected.WithOpName(name), dtype);
  };
  auto a = placeholder("A", DT_INT32);
  auto b = placeholder("B", DT_FLOAT);
  auto c = placeholder("C", DT_INT32);
  auto d = placeholder("D", DT_FLOAT);
  auto u = placeholder("U", DT_RESOURCE);
  auto v = placeholder("V", DT_RESOURCE);
  auto w = placeholder("W", DT_RESOURCE);

  NameAttrList callee;
  callee.set_name("launch0");
  auto launch = ops::XlaLaunch(
      expected.WithOpName("launch0").WithDevice("/gpu:0"),
      /*constants=*/std::initializer_list<Input>{},
      /*args=*/std::initializer_list<Input>{a, b, c, d},
      /*resources=*/std::initializer_list<Input>{u, v, w},
      /*Tresults=*/DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT},
      /*function=*/callee);

  auto consume = [&expected, &launch](const char* name, int result_index) {
    ops::Identity(expected.WithOpName(name), launch.results[result_index]);
  };
  consume("consumer0_a", 0);
  consume("consumer0_b", 0);
  consume("consumer0_c", 0);
  consume("consumer1", 1);
  consume("consumer2", 2);
  consume("consumer3", 3);

  GraphDef expected_def;
  TF_ASSERT_OK(expected.ToGraphDef(&expected_def));

  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
342 
343 }  // namespace tensorflow
344