1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/common_runtime/executor.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
50
51 namespace tensorflow {
52
53 class ExecutorTest : public ::testing::Test {
54 protected:
ExecutorTest()55 ExecutorTest()
56 : device_(DeviceFactory::NewDevice("CPU", {},
57 "/job:localhost/replica:0/task:0")),
58
59 step_stats_collector_(&step_stats_) {
60 SessionOptions options;
61 thread_pool_ = ComputePool(options);
62 }
63
~ExecutorTest()64 ~ExecutorTest() override {
65 // LocalRendezvous::AsyncRecv() might still executing after done_callback
66 // returns. Wait until the local rc_owner_ releases.
67 while (!rendez_->RefCountIsOne()) {
68 }
69 // There should always be exactly one Ref left on the Rendezvous
70 // when the test completes.
71 CHECK(rendez_->Unref());
72 delete exec_;
73 }
74
75 // Resets executor_ with a new executor based on a graph 'gdef'.
Create(std::unique_ptr<const Graph> graph)76 void Create(std::unique_ptr<const Graph> graph) {
77 const int version = graph->versions().producer();
78 LocalExecutorParams params;
79 params.device = device_.get();
80 params.create_kernel =
81 [this, version](const std::shared_ptr<const NodeProperties>& props,
82 OpKernel** kernel) {
83 return CreateNonCachedKernel(device_.get(), nullptr, props, version,
84 kernel);
85 };
86 params.delete_kernel = [](OpKernel* kernel) {
87 DeleteNonCachedKernel(kernel);
88 };
89 rendez_ = NewLocalRendezvous();
90 delete exec_;
91 TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
92 runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
93 }
94
Run(Rendezvous * rendez)95 Status Run(Rendezvous* rendez) {
96 Executor::Args args;
97 args.rendezvous = rendez;
98 args.stats_collector = &step_stats_collector_;
99 args.runner = runner_;
100 return exec_->Run(args);
101 }
102
103 thread::ThreadPool* thread_pool_ = nullptr;
104 std::unique_ptr<Device> device_;
105 Executor* exec_ = nullptr;
106 StepStatsCollector step_stats_collector_;
107 StepStats step_stats_;
108 Executor::Args::Runner runner_;
109 Rendezvous* rendez_ = nullptr;
110 };
111
112 // A float val -> Tensor<float>
V(const float val)113 Tensor V(const float val) {
114 Tensor tensor(DT_FLOAT, TensorShape({}));
115 tensor.scalar<float>()() = val;
116 return tensor;
117 }
118
119 // A int32 val -> Tensor<int32>
VI(const int32_t val)120 Tensor VI(const int32_t val) {
121 Tensor tensor(DT_INT32, TensorShape({}));
122 tensor.scalar<int32>()() = val;
123 return tensor;
124 }
125
126 // A bool val -> Tensor<bool>
VB(const bool val)127 Tensor VB(const bool val) {
128 Tensor tensor(DT_BOOL, TensorShape({}));
129 tensor.scalar<bool>()() = val;
130 return tensor;
131 }
132
133 // A double val -> Tensor<double>
VD(const double val)134 Tensor VD(const double val) {
135 Tensor tensor(DT_DOUBLE, TensorShape({}));
136 tensor.scalar<double>()() = val;
137 return tensor;
138 }
139
140 // Tensor<float> -> a float val.
V(const Tensor & tensor)141 float V(const Tensor& tensor) {
142 CHECK_EQ(tensor.dtype(), DT_FLOAT);
143 CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
144 return tensor.scalar<float>()();
145 }
146
147 static uint64 kIncarnation = 1; // Uses in following tests.
148
Key(const string & sender,const uint64 incarnation,const string & receiver,const string & name)149 Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
150 const string& receiver, const string& name) {
151 Rendezvous::ParsedKey result;
152 CHECK(
153 Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
154 name, FrameAndIter(0, 0)),
155 &result)
156 .ok());
157 return result;
158 }
159
// Device names of the two rendezvous endpoints used throughout this file.
// The tests and benchmarks pass them as send/recv device names and as the
// sender/receiver parts of rendezvous keys.
#define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/device:GPU:0"
162
// Feeds a = 1.0 and b = 1.0 through the rendezvous, runs c = a + b, and
// checks that c == 2.0 comes back.
TEST_F(ExecutorTest, SimpleAdd) {
  // Graph: c = a + b
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* lhs = test::graph::Recv(graph.get(), "a", "float", ALICE, 1, BOB);
  Node* rhs = test::graph::Recv(graph.get(), "b", "float", ALICE, 1, BOB);
  Node* sum = test::graph::Add(graph.get(), lhs, rhs);
  test::graph::Send(graph.get(), sum, "c", BOB, 1, ALICE);
  Create(std::move(graph));

  Rendezvous::Args args;
  // lhs = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  // rhs = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));

  Tensor result = V(-1);
  bool result_is_dead = false;
  TF_ASSERT_OK(rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &result,
                             &result_is_dead));
  EXPECT_EQ(2.0, V(result));  // 1.0 + 1.0 = 2.0
}
183
// Doubles the input ten times in a row and checks the result is 1024x the
// input:
//   v0 <- a
//   v1 = v0 + v0
//   v2 = v1 + v1
//   ...
//   v10 = v9 + v9
//   b <- v10
// All nodes are executed by one thread.
TEST_F(ExecutorTest, SelfAdd) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* current = test::graph::Recv(graph.get(), "a", "float", ALICE, 1, BOB);
  constexpr int kNumDoublings = 10;
  for (int step = 0; step < kNumDoublings; ++step) {
    current = test::graph::Add(graph.get(), current, current);
  }
  // b <- v10
  test::graph::Send(graph.get(), current, "b", BOB, 1, ALICE);
  Create(std::move(graph));

  Rendezvous::Args args;
  // a = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));

  Tensor result = V(-1);
  bool result_is_dead = false;
  TF_ASSERT_OK(rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &result,
                             &result_is_dead));
  EXPECT_EQ(1024.0, V(result));  // b = v10 = 2*v9 = 4*v8 = ... = 1024*a
}
213
214 // Builds a graph which adds N copies of one variable "in". I.e.,
215 // a + a + a + ... + a
216 // The returned graph is parenthesized ramdonly. I.e.,
217 // a + ((a + a) + a)
218 // (a + a) + (a + a)
219 // ((a + a) + a) + a
220 // are all possibly generated.
BuildTree(int N,Graph * g)221 void BuildTree(int N, Graph* g) {
222 CHECK_GT(N, 1);
223 // A single input node "in".
224 auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
225 std::vector<Node*> nodes;
226 int i = 0;
227 // Duplicate "in" N times. Each copies is named as l0, l1, l2, ....
228 for (; i < N; ++i) {
229 nodes.push_back(test::graph::Identity(g, in, 0));
230 }
231 random::PhiloxRandom philox(testing::RandomSeed(), 17);
232 random::SimplePhilox rnd(&philox);
233 while (nodes.size() > 1) {
234 // Randomly pick two from nodes and add them. The resulting node
235 // is named lik n10, n11, .... and is put back into "nodes".
236 int x = rnd.Uniform(nodes.size());
237 auto in0 = nodes[x];
238 nodes[x] = nodes.back();
239 nodes.resize(nodes.size() - 1);
240 x = rnd.Uniform(nodes.size());
241 auto in1 = nodes[x];
242 // node = in0 + in1.
243 nodes[x] = test::graph::Add(g, in0, in1);
244 }
245 // The final output node "out".
246 test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
247 }
248
// Sums 4096 copies of a = 1.0 through a randomly shaped tree of Add nodes and
// expects 4096.0.
TEST_F(ExecutorTest, RandomTree) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  BuildTree(4096, graph.get());
  Create(std::move(graph));

  Rendezvous::Args args;
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));

  Tensor result = V(-1);
  bool result_is_dead = false;
  TF_ASSERT_OK(rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &result,
                             &result_is_dead));
  EXPECT_EQ(4096.0, V(result));
}
263
BuildConcurrentAddAssign(Graph * g)264 void BuildConcurrentAddAssign(Graph* g) {
265 auto one = test::graph::Constant(g, V(1.0));
266 // A variable holds one float.
267 auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
268 // Initialize the variable with 1.0.
269 auto init = test::graph::Assign(g, var, one);
270 // Output
271 auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
272 // Have many concurrent computation. Each does v = v + 1.
273 for (int i = 0; i < 1024; ++i) {
274 auto add = test::graph::Add(g, var, one);
275 g->AddControlEdge(init, add); // Ensures run after init.
276 auto assign = test::graph::Assign(g, var, add);
277 g->AddControlEdge(assign, out);
278 }
279 }
280
281 #ifndef THREAD_SANITIZER
TEST_F(ExecutorTest,ConcurrentAddAssign)282 TEST_F(ExecutorTest, ConcurrentAddAssign) {
283 auto g = std::make_unique<Graph>(OpRegistry::Global());
284 BuildConcurrentAddAssign(g.get());
285 Create(std::move(g));
286 for (int iters = 0; iters < 16; ++iters) {
287 Rendezvous* rendez = NewLocalRendezvous();
288 TF_ASSERT_OK(Run(rendez));
289 Rendezvous::Args args;
290 Tensor out;
291 bool is_dead;
292 TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
293 &is_dead));
294 VLOG(1) << "Get " << V(out);
295 EXPECT_LE(V(out), 1025.0);
296 rendez->Unref();
297 }
298 }
299 #endif
300
// Routes input "a" through a Switch whose predicate is the constant `false`
// and sends the Switch node's output as "c"; expects the value to arrive
// live and unchanged.
TEST_F(ExecutorTest, SimpleSwitchLive) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* data = test::graph::Recv(graph.get(), "a", "float", ALICE, 1, BOB);
  Node* pred = test::graph::Constant(graph.get(), VB(false));
  auto switched = test::graph::Switch(graph.get(), data, pred);
  test::graph::Send(graph.get(), switched, "c", BOB, 1, ALICE);
  Create(std::move(graph));

  Rendezvous::Args args;
  // data = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));

  Tensor result = V(-1);
  bool result_is_dead = false;
  TF_ASSERT_OK(rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &result,
                             &result_is_dead));
  EXPECT_EQ(1.0, V(result));  // result = 1.0
  EXPECT_FALSE(result_is_dead);
}
319
// Same graph as SimpleSwitchLive but with the Switch predicate set to the
// constant `true`; expects the tensor received as "c" to be marked dead.
TEST_F(ExecutorTest, SimpleSwitchDead) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Node* data = test::graph::Recv(graph.get(), "a", "float", ALICE, 1, BOB);
  Node* pred = test::graph::Constant(graph.get(), VB(true));
  auto switched = test::graph::Switch(graph.get(), data, pred);
  test::graph::Send(graph.get(), switched, "c", BOB, 1, ALICE);
  Create(std::move(graph));

  Rendezvous::Args args;
  // data = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));

  Tensor result = V(-1);
  bool result_is_dead = false;
  TF_ASSERT_OK(rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &result,
                             &result_is_dead));
  EXPECT_TRUE(result_is_dead);
}
337
// Starts a graph that needs four inputs, supplies only three of them and
// aborts the rendezvous instead of sending the fourth. Both the executor run
// and a subsequent Recv of the graph output must report Aborted.
TEST_F(ExecutorTest, Abort) {
  // e = a + b + c + d
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB);
  auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB);
  auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB);
  auto add0 = test::graph::Add(g.get(), in0, in1);
  auto add1 = test::graph::Add(g.get(), in2, in3);
  auto add2 = test::graph::Add(g.get(), add0, add1);
  test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE);
  Create(std::move(g));

  // Needs 4 inputs (recv). Three are sent after a delay; "d" is never sent.
  // Each closure holds its own Ref on the rendezvous while it is pending.
  for (const char* input : {"a", "b", "c"}) {
    rendez_->Ref();
    SchedClosure([this, input]() {
      Env::Default()->SleepForMicroseconds(100 * 1000);
      // The send may race with the abort below, so its status is not checked.
      rendez_
          ->Send(Key(ALICE, kIncarnation, BOB, input), Rendezvous::Args(),
                 V(1.0), false)
          .IgnoreError();
      rendez_->Unref();
    });
  }
  // Instead of the fourth input, abort the rendezvous.
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    rendez_->StartAbort(errors::Aborted(""));
    rendez_->Unref();
  });
  EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
  Tensor out = V(-1);
  bool is_dead = false;
  // Ask for "e", the tensor this graph actually sends back.
  EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
      Key(BOB, kIncarnation, ALICE, "e"), Rendezvous::Args(), &out, &is_dead)));
  // At this point there can still be pending (albeit Aborted) Send
  // closures holding Refs on rendez_. We need to wait for them, or
  // else there can be a memory leak at termination. This wait logic is in test
  // dtor.
}
389
// Sends a double where the graph's Recv node is declared as float and
// expects both the run and the subsequent Recv of the output to fail with an
// Internal error.
TEST_F(ExecutorTest, RecvInvalidDtype) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  // An input vector of type float of size 1.
  auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB);
  // A floating point variable vector of size 1.
  auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1}));
  // Initialize the variable with input.
  auto init = test::graph::Assign(g.get(), var, one);
  // Output
  auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE);
  g->AddControlEdge(init, two);  // Ensures run after init.
  Create(std::move(g));
  Rendezvous* rendez = NewLocalRendezvous();
  // Release the local rendezvous on every exit path, including the early
  // return taken by a failing TF_ASSERT_OK.
  core::ScopedUnref rendez_unref(rendez);
  // Send a double instead of float.
  TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
                            VD(1.0), false));
  // Fails due to invalid dtype.
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead = false;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
}
414
// Sends the output of an InvalidRefType node built with DT_FLOAT/DT_DOUBLE
// and expects both the run and the subsequent Recv to fail with an Internal
// error.
TEST_F(ExecutorTest, RecvInvalidRefDtype) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  // A var that always produces an invalid dtype.
  auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE);
  test::graph::Send(g.get(), var, "out", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous* rendez = NewLocalRendezvous();
  // Release the local rendezvous on every exit path.
  core::ScopedUnref rendez_unref(rendez);
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead = false;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
}
429
// The executor must run cleanly on a graph whose only node (a Constant) takes
// no input tensors.
TEST_F(ExecutorTest, NoInputTensors) {
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  test::graph::Constant(graph.get(), V(1.0));
  Create(std::move(graph));
  TF_ASSERT_OK(Run(rendez_));
}
437
438 // Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
439 // maximum of 'width' nodes. All nodes are no-ops and all dependencies are
440 // control dependencies.
BM_executor(::testing::benchmark::State & state)441 static void BM_executor(::testing::benchmark::State& state) {
442 const int width = state.range(0);
443 const int depth = state.range(1);
444
445 Graph* g = new Graph(OpRegistry::Global());
446 random::PhiloxRandom philox(1729, 17);
447 random::SimplePhilox rand(&philox);
448 uint64 cur = 0;
449 uint32 r = 1 + rand.Rand32() % width;
450 std::vector<Node*> ready_nodes;
451 for (int i = 0; i < r; ++i) {
452 ready_nodes.push_back(test::graph::NoOp(g, {}));
453 ++cur;
454 }
455 std::random_device random_device;
456 std::mt19937 rng(random_device());
457 for (int i = 0; i < depth; ++i) {
458 std::shuffle(ready_nodes.begin(), ready_nodes.end(), rng);
459 r = 1 + rand.Rand32() % (ready_nodes.size());
460 std::vector<Node*> control_inputs;
461 for (int j = 0; j < r; ++j) {
462 control_inputs.push_back(ready_nodes.back());
463 ready_nodes.pop_back();
464 }
465 Node* n = test::graph::NoOp(g, control_inputs);
466 ++cur;
467 r = 1 + rand.Rand32() % width;
468 for (int j = 0; j < r; ++j) {
469 ready_nodes.push_back(test::graph::NoOp(g, {n}));
470 ++cur;
471 }
472 }
473
474 FixupSourceAndSinkEdges(g);
475 test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
476
477 state.SetLabel(strings::StrCat("Nodes = ", cur));
478 state.SetItemsProcessed(cur * static_cast<int64_t>(state.iterations()));
479 }
480
481 // Tall skinny graphs
482 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
483 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);
484
485 // Short fat graphs
486 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
487 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);
488
489 // Tall fat graph
490 BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);
491
BM_const_identity(::testing::benchmark::State & state)492 static void BM_const_identity(::testing::benchmark::State& state) {
493 const int width = state.range(0);
494 const int outputs_per_const = state.range(1);
495
496 Graph* g = new Graph(OpRegistry::Global());
497 for (int i = 0; i < width; ++i) {
498 Tensor i_t(i);
499 Node* const_node = test::graph::Constant(g, i_t);
500 for (int j = 0; j < outputs_per_const; ++j) {
501 test::graph::Identity(g, const_node);
502 }
503 }
504 FixupSourceAndSinkEdges(g);
505 test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
506 state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
507 state.SetItemsProcessed((1 + outputs_per_const) * width *
508 static_cast<int64_t>(state.iterations()));
509 }
510
511 // Graph with actual op execution.
512 BENCHMARK(BM_const_identity)
513 ->UseRealTime()
514 ->ArgPair(1, 1)
515 ->ArgPair(1, 100)
516 ->ArgPair(100, 1)
517 ->ArgPair(100, 100);
518
BM_FeedInputFetchOutput(::testing::benchmark::State & state)519 static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
520 Graph* g = new Graph(OpRegistry::Global());
521 // z = x + y: x and y are provided as benchmark inputs. z is the
522 // output of the benchmark. Conceptually, the caller is ALICE, the
523 // benchmark is BOB.
524 Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
525 Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
526 Node* sum = test::graph::Add(g, x, y);
527 Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
528
529 string x_key = test::GetRendezvousKey(x);
530 string y_key = test::GetRendezvousKey(y);
531 string z_key = test::GetRendezvousKey(z);
532
533 Tensor val(DT_FLOAT, TensorShape({}));
534 val.scalar<float>()() = 3.14;
535 FixupSourceAndSinkEdges(g);
536 test::Benchmark("cpu", g, /*old_benchmark_api=*/false)
537 .RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state);
538 state.SetItemsProcessed(static_cast<int64_t>(state.iterations()));
539 }
540 BENCHMARK(BM_FeedInputFetchOutput);
541
ReplaceEdgeWithSendRecv(Graph * g,const Edge * edge,const string & tensor,const string & sender,const uint64 sender_incarnation,const string & receiver)542 Status ReplaceEdgeWithSendRecv(Graph* g, const Edge* edge, const string& tensor,
543 const string& sender,
544 const uint64 sender_incarnation,
545 const string& receiver) {
546 Node* send;
547 NodeDef send_def;
548 TF_CHECK_OK(NodeDefBuilder(g->NewName("n"), "_Send")
549 .Input(edge->src()->name(), edge->src_output(),
550 edge->src()->output_type(edge->src_output()))
551 .Attr("tensor_name", tensor)
552 .Attr("send_device", sender)
553 .Attr("send_device_incarnation",
554 static_cast<int64_t>(sender_incarnation))
555 .Attr("recv_device", receiver)
556 .Finalize(&send_def));
557
558 TF_ASSIGN_OR_RETURN(send, g->AddNode(send_def));
559
560 Node* recv;
561 NodeDef recv_def;
562 TF_CHECK_OK(
563 NodeDefBuilder(g->NewName("n"), "_Recv")
564 .Attr("tensor_name", tensor)
565 .Attr("send_device", sender)
566 .Attr("send_device_incarnation",
567 static_cast<int64_t>(sender_incarnation))
568 .Attr("recv_device", receiver)
569 .Attr("tensor_type", edge->dst()->input_type(edge->dst_input()))
570 .Finalize(&recv_def));
571
572 TF_ASSIGN_OR_RETURN(recv, g->AddNode(recv_def));
573
574 g->AddEdge(edge->src(), edge->src_output(), send, 0);
575 g->AddEdge(recv, 0, edge->dst(), edge->dst_input());
576
577 // This control dependency can ensure Exit op can still be downstream
578 // op of Enter after inserting Send/Recv.
579 g->AddControlEdge(edge->src(), recv);
580
581 g->RemoveEdge(edge);
582 return OkStatus();
583 }
584
585 // Defines a graph to perform the following computation:
586 //
587 // i = 0
588 // while (i < loop_iters)
589 // i += 1;
590 //
591 // ...using the functional `WhileOp` (if `lower` is false) or the
592 // `Switch`/`Merge`-style of control flow (if `lower` is true).
BM_WhileLoopHelper(::testing::benchmark::State & state,int loop_iters,int loop_vars,bool lower,bool transfer)593 static void BM_WhileLoopHelper(::testing::benchmark::State& state,
594 int loop_iters, int loop_vars, bool lower,
595 bool transfer) {
596 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
597
598 // Add test functions for cond and body.
599 FunctionDefLibrary f_lib_proto;
600
601 // Define the loop body as a function: `x = x + 1`.
602 const Tensor one_t = test::AsScalar<int32>(1);
603
604 std::vector<string> args;
605 args.reserve(loop_vars);
606 args.push_back("x: int32");
607 for (int i = 1; i < loop_vars; ++i) {
608 args.push_back(strings::StrCat("x", i, ": int32"));
609 }
610
611 std::vector<string> body_rets;
612 body_rets.reserve(loop_vars);
613 body_rets.push_back("y: int32");
614 for (int i = 1; i < loop_vars; ++i) {
615 body_rets.push_back(strings::StrCat("y", i, ": int32"));
616 }
617
618 std::vector<FunctionDefHelper::Node> body_nodes;
619 body_nodes.reserve(1 + loop_vars);
620 body_nodes.push_back(
621 {{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}});
622 body_nodes.push_back({{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}});
623 for (int i = 1; i < loop_vars; ++i) {
624 body_nodes.push_back({{strings::StrCat("y", i)},
625 "Relu",
626 {strings::StrCat("x", i)},
627 {{"T", DT_INT32}}});
628 }
629
630 *f_lib_proto.add_function() = FunctionDefHelper::Define(
631 // Name
632 "XPlusOne",
633 // Args
634 args,
635 // Return values
636 body_rets,
637 // Attr def
638 {},
639 // Nodes
640 body_nodes);
641
642 // Define the loop condition as a function: `x < loop_iters`.
643 const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
644 *f_lib_proto.add_function() = FunctionDefHelper::Define(
645 // Name
646 "LessThanOrEqualToN",
647 // Args
648 args,
649 // Return values
650 {"z: bool"},
651 // Attr def
652 {},
653 // Nodes
654 {
655 {{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}},
656 {{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}},
657 });
658
659 Scope root = Scope::NewRootScope().ExitOnError();
660 TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
661 auto a = ops::Const(root.WithOpName("A"), 0, {});
662 Node* while_node;
663 std::vector<NodeBuilder::NodeOut> inputs;
664 std::vector<DataType> input_types(loop_vars, DT_INT32);
665 inputs.reserve(loop_vars);
666 for (int i = 0; i < loop_vars; ++i) {
667 inputs.push_back(NodeBuilder::NodeOut(a.node()));
668 }
669 AttrValue int32_attr;
670 int32_attr.set_type(DT_INT32);
671 AttrValue cond_func;
672 cond_func.mutable_func()->set_name("LessThanOrEqualToN");
673 AttrValue body_func;
674 body_func.mutable_func()->set_name("XPlusOne");
675 TF_ASSERT_OK(
676 NodeBuilder("while", "While", &root.graph()->flib_def())
677 .Input(inputs)
678 .Attr("T", input_types)
679 .Attr("cond", cond_func)
680 .Attr("body", body_func)
681 .Attr("parallel_iterations", 20)
682 .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
683 .Finalize(root.graph(), &while_node));
684 auto c = ops::Identity(
685 root.WithOpName("C").WithControlDependencies(Output(while_node)),
686 Output(while_node));
687 TF_ASSERT_OK(root.DoShapeInference(while_node));
688 TF_ASSERT_OK(root.ToGraph(graph.get()));
689
690 if (lower) {
691 FunctionLibraryDefinition flib_def(graph->flib_def());
692 GraphOptimizationPassOptions opt_options;
693 SessionOptions session_options;
694 session_options.config.mutable_graph_options()
695 ->mutable_optimizer_options()
696 ->set_do_function_inlining(true);
697 opt_options.session_options = &session_options;
698 opt_options.graph = &graph;
699 opt_options.flib_def = &flib_def;
700 LowerFunctionalOpsPass pass;
701 TF_ASSERT_OK(pass.Run(opt_options));
702
703 if (transfer) {
704 // Insert Send/Recv between LoopCond and Switch. This can represent
705 // distributed training loop which has been used widely in TF2.
706 for (Node* node : graph->nodes()) {
707 if (node->type_string() != "LoopCond") {
708 continue;
709 }
710
711 for (const Edge* edge : node->out_edges()) {
712 if (edge->dst()->type_string() != "Switch") {
713 continue;
714 }
715 string tensor_name = strings::StrCat("c", edge->id());
716 TF_ASSERT_OK(ReplaceEdgeWithSendRecv(graph.get(), edge, tensor_name,
717 BOB, 1, ALICE));
718 }
719 }
720 }
721 }
722
723 SessionOptions options;
724 options.config.set_inter_op_parallelism_threads(4);
725 FixupSourceAndSinkEdges(graph.get());
726 test::Benchmark("cpu", graph.release(), &options, nullptr, nullptr, "",
727 /*old_benchmark_api=*/false)
728 .Run(state);
729 }
730
BM_LoweredWhileLoop(::testing::benchmark::State & state)731 static void BM_LoweredWhileLoop(::testing::benchmark::State& state) {
732 const int loop_iters = state.range(0);
733 const int loop_vars = state.range(1);
734
735 BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true,
736 /* transfer= */ false);
737 }
738 BENCHMARK(BM_LoweredWhileLoop)
739 ->ArgPair(0, 1)
740 ->ArgPair(1, 1)
741 ->ArgPair(10, 1)
742 ->ArgPair(100, 1)
743 ->ArgPair(1000, 1)
744 ->ArgPair(0, 100)
745 ->ArgPair(1, 100)
746 ->ArgPair(10, 100)
747 ->ArgPair(100, 100)
748 ->ArgPair(1000, 100);
749
BM_LoweredWhileLoopWithTransfer(::testing::benchmark::State & state)750 static void BM_LoweredWhileLoopWithTransfer(
751 ::testing::benchmark::State& state) {
752 const int loop_iters = state.range(0);
753 const int loop_vars = state.range(1);
754
755 BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true,
756 /* transfer= */ true);
757 }
758 BENCHMARK(BM_LoweredWhileLoopWithTransfer)
759 ->ArgPair(0, 100)
760 ->ArgPair(1, 100)
761 ->ArgPair(10, 100)
762 ->ArgPair(100, 100)
763 ->ArgPair(1000, 100)
764 ->ArgPair(1, 5000)
765 ->ArgPair(10, 5000)
766 ->ArgPair(100, 5000)
767 ->ArgPair(1000, 5000);
768
BM_FunctionalWhileLoop(::testing::benchmark::State & state)769 static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) {
770 const int loop_iters = state.range(0);
771 const int loop_vars = state.range(1);
772
773 BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false,
774 /* transfer= */ false);
775 }
776 BENCHMARK(BM_FunctionalWhileLoop)
777 ->ArgPair(0, 1)
778 ->ArgPair(1, 1)
779 ->ArgPair(10, 1)
780 ->ArgPair(100, 1)
781 ->ArgPair(1000, 1)
782 ->ArgPair(0, 100)
783 ->ArgPair(1, 100)
784 ->ArgPair(10, 100)
785 ->ArgPair(100, 100)
786 ->ArgPair(1000, 100);
787 } // namespace tensorflow
788