/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

using ::testing::_;
using testing::matchers::AssignedDevice;
using testing::matchers::Attr;
using testing::matchers::Const;
using testing::matchers::CtrlDeps;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;

// A fake device used to populate a DeviceSet.
class FakeDevice : public Device {
 public:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

  Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }

  static std::unique_ptr<Device> Make(const string& name, const string& type) {
    DeviceAttributes device_attributes;
    device_attributes.set_name(name);
    device_attributes.set_device_type(DeviceType(type).type());
    return std::make_unique<FakeDevice>(device_attributes);
  }
};

const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";

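// Runs IncreaseDynamismForAutoJitPass over the graph described by `s`, using a
// fake one-GPU/one-CPU device set and global_jit_level ON_2, and returns the
// rewritten graph in `result`.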
Status IncreaseDynamismForAutoJit(const Scope& s,
                                  std::unique_ptr<Graph>* result) {
  std::vector<std::unique_ptr<Device>> devices;
  devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU));
  devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU));

  std::unique_ptr<DeviceSet> device_set(new DeviceSet());
  for (auto& device : devices) {
    device_set->AddDevice(device.get());
  }

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(OptimizerOptions::ON_2);
  GraphOptimizationPassOptions options;
  options.graph = &graph;
  options.device_set = device_set.get();
  options.session_options = &session_options;

  // Scope::ToGraph seems to drop assigned devices, probably because it goes
  // through a GraphDef. So explicitly maintain the device assignment.
  std::unordered_map<string, string> assigned_device_names;
  for (Node* n : s.graph()->nodes()) {
    assigned_device_names[n->name()] = n->assigned_device_name();
  }
  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    n->set_assigned_device_name(assigned_device_names[n->name()]);
  }

  IncreaseDynamismForAutoJitPass rewriter;
  TF_RETURN_IF_ERROR(rewriter.Run(options));
  *result = std::move(graph);
  return OkStatus();
}

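// A Slice whose size constant contains -1 should be rewritten into a
// static-shaped slice: the -1 size component is recomputed at run time on the
// host (Shape/Slice/Sub feeding a ConcatV2), and the resulting size operand is
// marked as a compile-time constant input for XLA.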
TEST(SliceToDynamicSliceRewriteTest, Basic) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  const int64_t zero_64 = 0;
  const int32_t zero_32 = 0;
  const int64_t one_64 = 1;

  auto m_input = Out(NodeWith(Op("Placeholder"), Name("input")));
  auto m_begin_s64 = Out(NodeWith(
      Op("Cast"), Inputs(Out(NodeWith(Op("Placeholder"), Name("begin"))))));
  auto m_input_shape = Out(NodeWith(Op("Shape"), Inputs(m_input)));
  auto m_slice_size_0 = Out(NodeWith(
      Op("Sub"), AssignedDevice(kHostName),
      Inputs(
          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                       Inputs(m_input_shape, Const(zero_64), Const(one_64)))),
          Out(NodeWith(Op("Slice"), AssignedDevice(kHostName),
                       Inputs(m_begin_s64, Const(zero_64), Const(one_64)))))));
  auto m_dynamic_slice_size =
      Out(NodeWith(Op("ConcatV2"), AssignedDevice(kHostName),
                   Inputs(m_slice_size_0, Const(static_cast<int64_t>(500)),
                          Const(zero_32))));

  std::vector<string> compile_time_constant_inputs;
  compile_time_constant_inputs.push_back("size");
  auto m_dynamic_slice = NodeWith(
      Op("Slice"), AssignedDevice(kDeviceName),
      Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs),
      Inputs(m_input, m_begin_s64, m_dynamic_slice_size));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
}

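// When slicing a vector, the recomputed size is a single scalar, so no
// ConcatV2 node should be needed to assemble the rewritten size operand.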
TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  EXPECT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2")))));
}

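// Control dependencies on the original Slice must carry over to the rewritten
// static-shaped slice.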
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output control_pred = ops::Placeholder(root.WithOpName("control"), DT_BOOL);
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
  root.graph()->AddControlEdge(control_pred.node(), slice.node());

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(static_shaped_slice,
              NodeWith(Op("Slice"),
                       CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
}

int64_t ToInt64(int v) { return static_cast<int64_t>(v); }

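// When `begin` is already DT_INT64 the rewrite should not introduce a Cast
// node.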
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size =
      ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("Cast")))));
}

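// A size constant with a negative component other than -1 is invalid, so the
// pass should leave the Slice untouched.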
TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);

  // The shape refiner throws an error if we use a bogus constant value for
  // size. So we first use a Placeholder to placate the shape refiner, and
  // later replace it with a bogus constant.
  Output size_placeholder =
      ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT32);
  Output slice =
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  Output size = ops::Const(root.WithOpName("size"), {-8, 500});
  TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(),
                                        /*new_src_index=*/0,
                                        /*dst=*/slice.node(), /*dst_index=*/2));

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

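// A Slice that is not assigned to an XLA cluster should not be rewritten.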
TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) {
  Scope root =
      Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName);

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

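// The rewrite only applies when the size operand is a constant; a Placeholder
// size should leave the Slice untouched.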
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size = ops::Placeholder(root.WithOpName("size"), DT_INT64);
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

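// With a scalar (rank-0) size constant, the rewritten slice is expected to
// reuse the original size constant directly as its size input.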
TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size = ops::Const<int64_t>(root.WithOpName("size"), {});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(), "slice/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_THAT(static_shaped_slice,
              NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr),
                       Inputs(_, _, Out(NodeWith(Name(size.node()->name()))))));
}

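// A size constant that is not a vector (here a 2x1 matrix) is malformed, so
// the pass should not rewrite the Slice.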
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  auto ToInt64 = [](int v) { return static_cast<int64_t>(v); };

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);

  // The C++ node bindings immediately error out when we try to construct a
  // bogus slice, so we first use a placeholder to construct the Slice and then
  // replace the size input.
  Output size_placeholder = ops::Placeholder(root.WithOpName("size"), DT_INT64);
  Output slice =
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  Output size =
      ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}});
  TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  EXPECT_THAT(result->nodes(),
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}

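// When the input to a rewritten Slice is itself a rewritten Slice, the new
// static-shaped slice must feed from the other rewrite's output and keep the
// original output dtype.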
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a);

  Output size_b = ops::Const(root.WithOpName("size_b"), {-1, 200});
  Output slice_with_slice_input = ops::Slice(
      root.WithOpName("slice_with_slice_input"), slice, begin, size_b);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(),
      "slice_with_slice_input/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
      << "Expected DT_FLOAT, was "
      << DataType_Name(static_shaped_slice->output_type(0));
  EXPECT_THAT(
      static_shaped_slice,
      NodeWith(
          Op("Slice"),
          Inputs(Out(NodeWith(
                     Op("Slice"),
                     Name("slice/static_shaped_slice/static_shaped_slice"))),
                 _, _)));
}

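// Same as above, but here the `begin` operand (rather than the input) comes
// from another rewritten Slice.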
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input_float =
      ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT);
  Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64);

  Output begin_begin =
      ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32);
  Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1});
  Output begin =
      ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size);

  Output size =
      ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)});
  Output slice_with_slice_begin = ops::Slice(
      root.WithOpName("slice_with_slice_begin"), input_float, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* static_shaped_slice = testing::FindNodeByName(
      result.get(),
      "slice_with_slice_begin/static_shaped_slice/static_shaped_slice");
  ASSERT_NE(static_shaped_slice, nullptr);
  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
      << "Expected DT_FLOAT, was "
      << DataType_Name(static_shaped_slice->output_type(0));
  EXPECT_THAT(
      static_shaped_slice,
      NodeWith(
          Op("Slice"),
          Inputs(_,
                 Out(NodeWith(
                     Op("Slice"),
                     Name("begin/static_shaped_slice/static_shaped_slice"))),
                 _)));
}

// New constants created by the rewrite need to have control dependencies
// copied over to them to ensure correct control flow analysis in TF v2.
TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
  Output size = ops::Const(root.WithOpName("size"), {-1});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  // Add a control dependency on the size constant; it should also be present
  // on the new size constants created by the rewrite.
  Output dependency = ops::Placeholder(root.WithOpName("dependency"), DT_BOOL);
  root.graph()->AddControlEdge(dependency.node(), size.node());

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  // Check that the new constants have control dependencies.
  Node* const_0 = testing::FindNodeByName(result.get(),
                                          "slice/static_shaped_slice/const_0");
  EXPECT_NE(const_0, nullptr);
  EXPECT_THAT(const_0,
              NodeWith(Op("Const"), CtrlDeps(NodeWith(Op("Placeholder"),
                                                      Name("dependency")))));
}

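// If `begin` is a constant the pass leaves the Slice alone; the original
// Placeholder/Const/Const inputs should be preserved.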
TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Const(root.WithOpName("begin"), {10, 10});
  Output size = ops::Const(root.WithOpName("size"), {-1, 500});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));

  Node* slice_node = testing::FindNodeByName(result.get(), "slice");
  EXPECT_THAT(slice_node,
              NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))),
                                           Out(NodeWith(Op("Const"))),
                                           Out(NodeWith(Op("Const"))))));
}

}  // namespace
}  // namespace tensorflow