1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/compiler/jit/node_matchers.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/public/session_options.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 using ::testing::_;
31 using testing::matchers::AssignedDevice;
32 using testing::matchers::Attr;
33 using testing::matchers::Const;
34 using testing::matchers::CtrlDeps;
35 using testing::matchers::Inputs;
36 using testing::matchers::Name;
37 using testing::matchers::NodeWith;
38 using testing::matchers::Op;
39 using testing::matchers::Out;
40 
41 // A fake device used to populate a DeviceSet.
42 class FakeDevice : public Device {
43  public:
FakeDevice(const DeviceAttributes & device_attributes)44   explicit FakeDevice(const DeviceAttributes& device_attributes)
45       : Device(nullptr, device_attributes) {}
46 
Sync()47   Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
48 
GetAllocator(AllocatorAttributes attr)49   Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
50 
Make(const string & name,const string & type)51   static std::unique_ptr<Device> Make(const string& name, const string& type) {
52     DeviceAttributes device_attributes;
53     device_attributes.set_name(name);
54     device_attributes.set_device_type(DeviceType(type).type());
55     return std::make_unique<FakeDevice>(device_attributes);
56   }
57 };
58 
// Full names of the two fake devices registered by IncreaseDynamismForAutoJit:
// the CPU (host) device and the GPU device, both on the same worker task.
// constexpr arrays (rather than `const char*`) so the globals are immutable.
constexpr char kHostName[] = "/job:worker/replica:0/task:0/device:CPU:0";
constexpr char kDeviceName[] = "/job:worker/replica:0/task:0/device:GPU:0";
61 
IncreaseDynamismForAutoJit(const Scope & s,std::unique_ptr<Graph> * result)62 Status IncreaseDynamismForAutoJit(const Scope& s,
63                                   std::unique_ptr<Graph>* result) {
64   std::vector<std::unique_ptr<Device>> devices;
65   devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU));
66   devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU));
67 
68   std::unique_ptr<DeviceSet> device_set(new DeviceSet());
69   for (auto& device : devices) {
70     device_set->AddDevice(device.get());
71   }
72 
73   auto graph = std::make_unique<Graph>(OpRegistry::Global());
74   SessionOptions session_options;
75   session_options.config.mutable_graph_options()
76       ->mutable_optimizer_options()
77       ->set_global_jit_level(OptimizerOptions::ON_2);
78   GraphOptimizationPassOptions options;
79   options.graph = &graph;
80   options.device_set = device_set.get();
81   options.session_options = &session_options;
82 
83   // Scope::ToGraph seems to drop assigned devices, probably because it goes
84   // through a GraphDef.  So explicitly maintain the device assignment.
85   std::unordered_map<string, string> assigned_device_names;
86   for (Node* n : s.graph()->nodes()) {
87     assigned_device_names[n->name()] = n->assigned_device_name();
88   }
89   TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
90   for (Node* n : graph->nodes()) {
91     n->set_assigned_device_name(assigned_device_names[n->name()]);
92   }
93 
94   IncreaseDynamismForAutoJitPass rewriter;
95   TF_RETURN_IF_ERROR(rewriter.Run(options));
96   *result = std::move(graph);
97   return OkStatus();
98 }
99 
// A clustered Slice whose constant size contains -1 is rewritten.  The new
// Slice gets an explicit size that is computed on the host from the input
// shape and `begin`, and its "size" input is marked compile-time constant.
TEST(SliceToDynamicSliceRewriteTest, Basic) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  const int64_t zero_s64 = 0;
  const int64_t one_s64 = 1;
  const int64_t five_hundred_s64 = 500;
  const int32_t zero_s32 = 0;

  // Building blocks for the expected graph.
  auto input_out = Out(NodeWith(Op("Placeholder"), Name("input")));
  auto begin_placeholder = Out(NodeWith(Op("Placeholder"), Name("begin")));
  auto begin_as_s64 = Out(NodeWith(Op("Cast"), Inputs(begin_placeholder)));
  auto input_shape = Out(NodeWith(Op("Shape"), Inputs(input_out)));

  // size[0] = shape(input)[0:1] - begin[0:1], all computed on the host.
  auto leading_shape_dim =
      Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                   Inputs(input_shape, Const(zero_s64), Const(one_s64))));
  auto leading_begin =
      Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                   Inputs(begin_as_s64, Const(zero_s64), Const(one_s64))));
  auto size_dim_0 =
      Out(NodeWith(Op("Sub"), AssignedDevice(kHostName),
                   Inputs(leading_shape_dim, leading_begin)));
  auto computed_size = Out(
      NodeWith(Op("ConcatV2"), AssignedDevice(kHostName),
               Inputs(size_dim_0, Const(five_hundred_s64), Const(zero_s32))));

  const std::vector<string> expected_constant_inputs = {"size"};
  auto expected_slice = NodeWith(
      Op("Slice"), AssignedDevice(kDeviceName),
      Attr(kXlaCompileTimeConstantInputsAttr, expected_constant_inputs),
      Inputs(input_out, begin_as_s64, computed_size));

  Node* rewritten = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(rewritten, nullptr);
  EXPECT_THAT(rewritten, expected_slice);
}
146 
// With a single-element size, the rewrite should still produce a
// static_shaped_slice node, but the graph should contain no ConcatV2 node.
TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(scope.WithOpName("size"), {-1});
  Output slice = ops::Slice(scope.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  Node* rewritten_slice = testing::FindNodeByName(
      rewritten_graph.get(), "slice/static_shaped_slice/static_shaped_slice");
  EXPECT_NE(rewritten_slice, nullptr);
  EXPECT_THAT(rewritten_graph->nodes(),
              Not(Contains(NodeWith(Op("ConcatV2")))));
}
166 
// A control edge into the original Slice must also be present on the
// rewritten Slice.
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(scope.WithOpName("size"), {-1, 500});
  Output control_pred = ops::Placeholder(scope.WithOpName("control"), DT_BOOL);
  Output slice = ops::Slice(scope.WithOpName("slice"), input, begin, size);
  scope.graph()->AddControlEdge(control_pred.node(), slice.node());

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  Node* rewritten_slice = testing::FindNodeByName(
      rewritten_graph.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(rewritten_slice, nullptr);

  auto control_source = NodeWith(Op("Placeholder"), Name("control"));
  EXPECT_THAT(rewritten_slice,
              NodeWith(Op("Slice"), CtrlDeps(control_source)));
}
190 
// Widens `v` to int64_t so brace-initialized constants below are built as int64.
int64_t ToInt64(int v) { return int64_t{v}; }
192 
// When `begin` and `size` are already int64, the rewritten graph must not
// contain any Cast node.
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT64);
  Output size =
      ops::Const(scope.WithOpName("size"), {ToInt64(-1), ToInt64(500)});
  Output slice = ops::Slice(scope.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  auto cast_node = NodeWith(Op("Cast"));
  EXPECT_THAT(rewritten_graph->nodes(), Not(Contains(cast_node)));
}
210 
// A Slice whose constant size contains the bogus value -8 must be left alone:
// no Slice in the result may carry the compile-time-constant-inputs attribute.
TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT32);

  // The shape refiner rejects a bogus constant size when the Slice is built,
  // so build the Slice against a Placeholder first and then rewire its size
  // edge to the bogus constant.
  Output size_placeholder =
      ops::Placeholder(scope.WithOpName("size_placeholder"), DT_INT32);
  Output slice =
      ops::Slice(scope.WithOpName("slice"), input, begin, size_placeholder);

  Output bogus_size = ops::Const(scope.WithOpName("size"), {-8, 500});
  TF_ASSERT_OK(scope.graph()->UpdateEdge(/*new_src=*/bogus_size.node(),
                                         /*new_src_index=*/0,
                                         /*dst=*/slice.node(),
                                         /*dst_index=*/2));

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  auto rewritten_slice =
      NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr));
  EXPECT_THAT(rewritten_graph->nodes(), Not(Contains(rewritten_slice)));
}
240 
// A Slice built in a scope without WithXlaCluster must not be rewritten.
TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) {
  Scope unclustered =
      Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName);

  Output input = ops::Placeholder(unclustered.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(unclustered.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(unclustered.WithOpName("size"), {-1, 500});
  Output slice =
      ops::Slice(unclustered.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(unclustered, &rewritten_graph));

  auto rewritten_slice =
      NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr));
  EXPECT_THAT(rewritten_graph->nodes(), Not(Contains(rewritten_slice)));
}
257 
// A Slice whose size comes from a Placeholder (not a Const) must not be
// rewritten.
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT64);
  Output runtime_size = ops::Placeholder(scope.WithOpName("size"), DT_INT64);
  Output slice =
      ops::Slice(scope.WithOpName("slice"), input, begin, runtime_size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  auto rewritten_slice =
      NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr));
  EXPECT_THAT(rewritten_graph->nodes(), Not(Contains(rewritten_slice)));
}
276 
// With the empty int64 `size` constant built here, the Slice is still
// rewritten (it gets the compile-time-constant-inputs attribute) and its third
// input is the original `size` node.
TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(scope.WithOpName("begin"), DT_INT64);
  Output size = ops::Const<int64_t>(scope.WithOpName("size"), {});
  Output slice = ops::Slice(scope.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  Node* rewritten_slice = testing::FindNodeByName(
      rewritten_graph.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(rewritten_slice, nullptr);

  auto original_size = Out(NodeWith(Name(size.node()->name())));
  EXPECT_THAT(rewritten_slice,
              NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr),
                       Inputs(_, _, original_size)));
}
298 
// A `size` constant of shape [2, 1] (not a vector) must not be rewritten:
// no Slice in the result may carry the compile-time-constant-inputs attribute.
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);

  // The C++ op bindings immediately error out when we try to construct a bogus
  // slice, so we first use a placeholder to construct the Slice and then
  // replace the input.  The placeholder gets its own name so that the Const
  // below, not the dangling placeholder, is the node called "size".
  Output size_placeholder =
      ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT64);
  Output slice =
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  // Uses the file-level ToInt64 helper; no local redefinition is needed.
  Output size =
      ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}});
  TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(),
                                        /*new_src_index=*/0,
                                        /*dst=*/slice.node(), /*dst_index=*/2));

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}
328 
// A Slice whose data input is itself a rewritten Slice is also rewritten. The
// outer Slice keeps a DT_FLOAT output and takes the rewritten inner Slice as
// its first input.
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a);

  // Named "size_b" to match its variable (it was previously a duplicate
  // "size_a" that the scope silently uniquified).
  Output size_b = ops::Const(root.WithOpName("size_b"), {-1, 200});
  Output slice_with_slice_input = ops::Slice(
      root.WithOpName("slice_with_slice_input"), slice, begin, size_b);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(),
      "slice_with_slice_input/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
      << "Expected DT_FLOAT, was "
      << DataType_Name(static_shaped_slice->output_type(0));
  EXPECT_THAT(
      static_shaped_slice,
      NodeWith(
          Op("Slice"),
          Inputs(Out(NodeWith(
                     Op("Slice"),
                     Name("slice/static_shaped_slice/static_shaped_slice"))),
                 _, _)));
}
363 
// When a Slice's `begin` operand is produced by another Slice, both are
// rewritten. The outer Slice keeps a DT_FLOAT output and reads the rewritten
// inner Slice as its `begin` input.
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) {
  Scope scope = Scope::NewRootScope()
                    .ExitOnError()
                    .WithAssignedDevice(kDeviceName)
                    .WithXlaCluster("cluster_0");

  Output input_float =
      ops::Placeholder(scope.WithOpName("input_float"), DT_FLOAT);
  Output input_i64 = ops::Placeholder(scope.WithOpName("input_i64"), DT_INT64);

  // Inner Slice that produces the outer Slice's `begin`.
  Output inner_begin =
      ops::Placeholder(scope.WithOpName("begin_begin"), DT_INT32);
  Output inner_size = ops::Const(scope.WithOpName("begin_size"), {-1});
  Output begin = ops::Slice(scope.WithOpName("begin"), input_i64, inner_begin,
                            inner_size);

  // Outer Slice under test.
  Output size =
      ops::Const(scope.WithOpName("size"), {ToInt64(-1), ToInt64(200)});
  Output outer_slice = ops::Slice(scope.WithOpName("slice_with_slice_begin"),
                                  input_float, begin, size);

  std::unique_ptr<Graph> rewritten_graph;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(scope, &rewritten_graph));

  Node* rewritten_outer = testing::FindNodeByName(
      rewritten_graph.get(),
      "slice_with_slice_begin/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(rewritten_outer, nullptr);

  const DataType outer_type = rewritten_outer->output_type(0);
  EXPECT_EQ(outer_type, DT_FLOAT)
      << "Expected DT_FLOAT, was " << DataType_Name(outer_type);

  auto rewritten_inner = Out(NodeWith(
      Op("Slice"), Name("begin/static_shaped_slice/static_shaped_slice")));
  EXPECT_THAT(rewritten_outer,
              NodeWith(Op("Slice"), Inputs(_, rewritten_inner, _)));
}
405 
406 // New constants being created need to have control dependencies copied to
407 // ensure correct control flow analysis in TF V2.
TEST(SliceToDynamicSliceRewriteTest,WithControlDepsToConstant)408 TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) {
409   Scope root = Scope::NewRootScope()
410                    .ExitOnError()
411                    .WithAssignedDevice(kDeviceName)
412                    .WithXlaCluster("cluster_0");
413 
414   Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
415   Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
416   Output size = ops::Const(root.WithOpName("size"), {-1});
417   Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
418 
419   // Add an additional dependency that should still exist in with the new size
420   // variables.
421   Output dependency = ops::Placeholder(root.WithOpName("dependency"), DT_BOOL);
422   root.graph()->AddControlEdge(dependency.node(), size.node());
423 
424   std::unique_ptr<Graph> result;
425   TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
426 
427   // Check that the new constants have control dependencies.
428   Node* const_0 = testing::FindNodeByName(result.get(),
429                                           "slice/static_shaped_slice/const_0");
430   EXPECT_NE(const_0, nullptr);
431   EXPECT_THAT(const_0,
432               NodeWith(Op("Const"), CtrlDeps(NodeWith(Op("Placeholder"),
433                                                       Name("dependency")))));
434 }
435 
// When `begin` is a constant, the Slice named "slice" is left as is: it still
// reads the Placeholder input and two Const operands directly.
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Const(root.WithOpName("begin"), {10, 10});
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  // Fail cleanly if the node is missing rather than matching against nullptr.
  Node* slice_node = testing::FindNodeByName(result.get(), "slice");
  ASSERT_NE(slice_node, nullptr);
  EXPECT_THAT(slice_node,
              NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))),
                                           Out(NodeWith(Op("Const"))),
                                           Out(NodeWith(Op("Const"))))));
}
456 
457 }  // namespace
458 }  // namespace tensorflow
459