/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/executor.h"

#include <algorithm>
#include <random>

20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/array_ops.h"
22 #include "tensorflow/cc/ops/const_op.h"
23 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
24 #include "tensorflow/cc/ops/function_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/device_factory.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
30 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
31 #include "tensorflow/core/common_runtime/process_util.h"
32 #include "tensorflow/core/common_runtime/step_stats_collector.h"
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/op.h"
35 #include "tensorflow/core/framework/rendezvous.h"
36 #include "tensorflow/core/framework/step_stats.pb.h"
37 #include "tensorflow/core/framework/tensor_testutil.h"
38 #include "tensorflow/core/framework/versions.pb.h"
39 #include "tensorflow/core/graph/algorithm.h"
40 #include "tensorflow/core/graph/testlib.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "tensorflow/core/lib/random/simple_philox.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/strcat.h"
46 #include "tensorflow/core/platform/test.h"
47 #include "tensorflow/core/platform/test_benchmark.h"
48 #include "tensorflow/core/platform/tracing.h"
49 #include "tensorflow/core/public/session_options.h"
50 
namespace tensorflow {

class ExecutorTest : public ::testing::Test {
 protected:
  ExecutorTest()
      : device_(DeviceFactory::NewDevice("CPU", {},
                                         "/job:localhost/replica:0/task:0")),

        step_stats_collector_(&step_stats_) {
    SessionOptions options;
    thread_pool_ = ComputePool(options);
  }

  ~ExecutorTest() override {
    // LocalRendezvous::AsyncRecv() might still be executing after the
    // done_callback returns. Wait until the local rc_owner_ releases it.
    while (!rendez_->RefCountIsOne()) {
    }
    // There should always be exactly one Ref left on the Rendezvous
    // when the test completes.
    CHECK(rendez_->Unref());
    delete exec_;
  }

  // Resets exec_ with a new executor built from 'graph'.
  void Create(std::unique_ptr<const Graph> graph) {
    const int version = graph->versions().producer();
    LocalExecutorParams params;
    params.device = device_.get();
    params.create_kernel =
        [this, version](const std::shared_ptr<const NodeProperties>& props,
                        OpKernel** kernel) {
          return CreateNonCachedKernel(device_.get(), nullptr, props, version,
                                       kernel);
        };
    params.delete_kernel = [](OpKernel* kernel) {
      DeleteNonCachedKernel(kernel);
    };
    rendez_ = NewLocalRendezvous();
    delete exec_;
    TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
    runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
  }

  Status Run(Rendezvous* rendez) {
    Executor::Args args;
    args.rendezvous = rendez;
    args.stats_collector = &step_stats_collector_;
    args.runner = runner_;
    return exec_->Run(args);
  }

  thread::ThreadPool* thread_pool_ = nullptr;
  std::unique_ptr<Device> device_;
  Executor* exec_ = nullptr;
  StepStatsCollector step_stats_collector_;
  StepStats step_stats_;
  Executor::Args::Runner runner_;
  Rendezvous* rendez_ = nullptr;
};

// A float val -> Tensor<float>
Tensor V(const float val) {
  Tensor tensor(DT_FLOAT, TensorShape({}));
  tensor.scalar<float>()() = val;
  return tensor;
}

// An int32 val -> Tensor<int32>
Tensor VI(const int32_t val) {
  Tensor tensor(DT_INT32, TensorShape({}));
  tensor.scalar<int32>()() = val;
  return tensor;
}

// A bool val -> Tensor<bool>
Tensor VB(const bool val) {
  Tensor tensor(DT_BOOL, TensorShape({}));
  tensor.scalar<bool>()() = val;
  return tensor;
}

// A double val -> Tensor<double>
Tensor VD(const double val) {
  Tensor tensor(DT_DOUBLE, TensorShape({}));
  tensor.scalar<double>()() = val;
  return tensor;
}

// Tensor<float> -> a float val.
float V(const Tensor& tensor) {
  CHECK_EQ(tensor.dtype(), DT_FLOAT);
  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
  return tensor.scalar<float>()();
}

static uint64 kIncarnation = 1;  // Used in the following tests.

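// Builds a parsed rendezvous key for transferring tensor `name` from `sender`
// to `receiver` under the given `incarnation`.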
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
                          const string& receiver, const string& name) {
  Rendezvous::ParsedKey result;
  CHECK(
      Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
                                                 name, FrameAndIter(0, 0)),
                           &result)
          .ok());
  return result;
}

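// Fake device names used as the send/recv endpoints in the rendezvous keys
// throughout these tests.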
#define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/device:GPU:0"

TEST_F(ExecutorTest, SimpleAdd) {
  // c = a + b
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB);
  auto tmp = test::graph::Add(g.get(), in0, in1);
  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
                             false));  // in1 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_EQ(2.0, V(out));  // out = 1.0 + 1.0 = 2.0
}

TEST_F(ExecutorTest, SelfAdd) {
  // v0 <- a
  // v1 = v0 + v0
  // v2 = v1 + v1
  // ... ...
  // v10 = v9 + v9
  //
  // b <- v10
  // All nodes are executed by one thread.
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto v = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  const int N = 10;
  for (int i = 1; i <= N; ++i) {
    v = test::graph::Add(g.get(), v, v);
  }
  // out <- v10
  test::graph::Send(g.get(), v, "b", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous::Args args;
  // a = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
  EXPECT_EQ(1024.0, V(out));  // b=v10=2*v9=4*v8=...=1024*a=1024.0
}

// Builds a graph which adds N copies of one input "in", i.e.,
//     a + a + a + ... + a
// The returned graph is parenthesized randomly, e.g.,
//     a + ((a + a) + a)
//     (a + a) + (a + a)
//     ((a + a) + a) + a
// may all be generated.
void BuildTree(int N, Graph* g) {
  CHECK_GT(N, 1);
  // A single input node "in".
  auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  std::vector<Node*> nodes;
  int i = 0;
  // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
  for (; i < N; ++i) {
    nodes.push_back(test::graph::Identity(g, in, 0));
  }
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  while (nodes.size() > 1) {
    // Randomly pick two nodes from "nodes" and add them. The resulting node
    // is named like n10, n11, ... and is put back into "nodes".
    int x = rnd.Uniform(nodes.size());
    auto in0 = nodes[x];
    nodes[x] = nodes.back();
    nodes.resize(nodes.size() - 1);
    x = rnd.Uniform(nodes.size());
    auto in1 = nodes[x];
    // node = in0 + in1.
    nodes[x] = test::graph::Add(g, in0, in1);
  }
  // The final output node "out".
  test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
}

TEST_F(ExecutorTest, RandomTree) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  BuildTree(4096, g.get());
  Create(std::move(g));
  Rendezvous::Args args;
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
  EXPECT_EQ(4096.0, V(out));
}

void BuildConcurrentAddAssign(Graph* g) {
  auto one = test::graph::Constant(g, V(1.0));
  // A variable that holds one float.
  auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
  // Initialize the variable with 1.0.
  auto init = test::graph::Assign(g, var, one);
  // Output
  auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
  // Have many concurrent computations. Each does v = v + 1.
  for (int i = 0; i < 1024; ++i) {
    auto add = test::graph::Add(g, var, one);
    g->AddControlEdge(init, add);  // Ensures it runs after init.
    auto assign = test::graph::Assign(g, var, add);
    g->AddControlEdge(assign, out);
  }
}

#ifndef THREAD_SANITIZER
TEST_F(ExecutorTest, ConcurrentAddAssign) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  BuildConcurrentAddAssign(g.get());
  Create(std::move(g));
  for (int iters = 0; iters < 16; ++iters) {
    Rendezvous* rendez = NewLocalRendezvous();
    TF_ASSERT_OK(Run(rendez));
    Rendezvous::Args args;
    Tensor out;
    bool is_dead;
    TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
                              &is_dead));
    VLOG(1) << "Get " << V(out);
    EXPECT_LE(V(out), 1025.0);
    rendez->Unref();
  }
}
#endif

TEST_F(ExecutorTest, SimpleSwitchLive) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Constant(g.get(), VB(false));
  auto tmp = test::graph::Switch(g.get(), in0, in1);
  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_EQ(1.0, V(out));  // out = 1.0
  EXPECT_FALSE(is_dead);
}

TEST_F(ExecutorTest, SimpleSwitchDead) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Constant(g.get(), VB(true));
  auto tmp = test::graph::Switch(g.get(), in0, in1);
  test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_TRUE(is_dead);
}

TEST_F(ExecutorTest, Abort) {
  // e = a + b + c + d
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB);
  auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB);
  auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB);
  auto add0 = test::graph::Add(g.get(), in0, in1);
  auto add1 = test::graph::Add(g.get(), in2, in3);
  auto add2 = test::graph::Add(g.get(), add0, add1);
  test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE);
  Create(std::move(g));

  // Needs 4 inputs (recv). One of them is aborted.
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    rendez_->StartAbort(errors::Aborted(""));
    rendez_->Unref();
  });
  EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
  Tensor out = V(-1);
  bool is_dead = false;
  EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
      Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
  // At this point there can still be pending (albeit Aborted) Send closures
  // holding Refs on rendez_. We need to wait for them, or else there can be a
  // memory leak at termination. That wait logic is in the test dtor.
}

TEST_F(ExecutorTest, RecvInvalidDtype) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  // An input vector of type float of size 1.
  auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB);
  // A floating point variable vector of size 1.
  auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1}));
  // Initialize the variable with the input.
  auto init = test::graph::Assign(g.get(), var, one);
  // Output
  auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE);
  g->AddControlEdge(init, two);  // Ensures it runs after init.
  Create(std::move(g));
  Rendezvous* rendez = NewLocalRendezvous();
  // Send a double instead of a float.
  TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
                            VD(1.0), false));
  // Fails due to invalid dtype.
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
  rendez->Unref();
}

TEST_F(ExecutorTest, RecvInvalidRefDtype) {
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  // A var that always produces an invalid dtype.
  auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE);
  test::graph::Send(g.get(), var, "out", BOB, 1, ALICE);
  Create(std::move(g));
  Rendezvous* rendez = NewLocalRendezvous();
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
  rendez->Unref();
}

TEST_F(ExecutorTest, NoInputTensors) {
  // Create a graph where none of the nodes have input tensors.
  auto g = std::make_unique<Graph>(OpRegistry::Global());
  test::graph::Constant(g.get(), V(1.0));
  Create(std::move(g));
  TF_ASSERT_OK(Run(rendez_));
}

// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a
// maximum of 'width' nodes. All nodes are no-ops and all dependencies are
// control dependencies.
static void BM_executor(::testing::benchmark::State& state) {
  const int width = state.range(0);
  const int depth = state.range(1);

  Graph* g = new Graph(OpRegistry::Global());
  random::PhiloxRandom philox(1729, 17);
  random::SimplePhilox rand(&philox);
  uint64 cur = 0;
  uint32 r = 1 + rand.Rand32() % width;
  std::vector<Node*> ready_nodes;
  for (int i = 0; i < r; ++i) {
    ready_nodes.push_back(test::graph::NoOp(g, {}));
    ++cur;
  }
  std::random_device random_device;
  std::mt19937 rng(random_device());
  for (int i = 0; i < depth; ++i) {
    std::shuffle(ready_nodes.begin(), ready_nodes.end(), rng);
    r = 1 + rand.Rand32() % (ready_nodes.size());
    std::vector<Node*> control_inputs;
    for (int j = 0; j < r; ++j) {
      control_inputs.push_back(ready_nodes.back());
      ready_nodes.pop_back();
    }
    Node* n = test::graph::NoOp(g, control_inputs);
    ++cur;
    r = 1 + rand.Rand32() % width;
    for (int j = 0; j < r; ++j) {
      ready_nodes.push_back(test::graph::NoOp(g, {n}));
      ++cur;
    }
  }

  FixupSourceAndSinkEdges(g);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);

  state.SetLabel(strings::StrCat("Nodes = ", cur));
  state.SetItemsProcessed(cur * static_cast<int64_t>(state.iterations()));
}

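// Benchmark arguments are (width, depth).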
// Tall skinny graphs
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(16, 1024);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(32, 8192);

// Short fat graphs
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 16);
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(8192, 32);

// Tall fat graph
BENCHMARK(BM_executor)->UseRealTime()->ArgPair(1024, 1024);

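// Builds a graph of `width` Const nodes, each feeding `outputs_per_const`
// Identity nodes. Benchmark arguments are (width, outputs_per_const).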
static void BM_const_identity(::testing::benchmark::State& state) {
  const int width = state.range(0);
  const int outputs_per_const = state.range(1);

  Graph* g = new Graph(OpRegistry::Global());
  for (int i = 0; i < width; ++i) {
    Tensor i_t(i);
    Node* const_node = test::graph::Constant(g, i_t);
    for (int j = 0; j < outputs_per_const; ++j) {
      test::graph::Identity(g, const_node);
    }
  }
  FixupSourceAndSinkEdges(g);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
  state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width));
  state.SetItemsProcessed((1 + outputs_per_const) * width *
                          static_cast<int64_t>(state.iterations()));
}

// Graph with actual op execution.
BENCHMARK(BM_const_identity)
    ->UseRealTime()
    ->ArgPair(1, 1)
    ->ArgPair(1, 100)
    ->ArgPair(100, 1)
    ->ArgPair(100, 100);

static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
  Graph* g = new Graph(OpRegistry::Global());
  // z = x + y: x and y are provided as benchmark inputs.  z is the
  // output of the benchmark.  Conceptually, the caller is ALICE, the
  // benchmark is BOB.
  Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
  Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
  Node* sum = test::graph::Add(g, x, y);
  Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);

  string x_key = test::GetRendezvousKey(x);
  string y_key = test::GetRendezvousKey(y);
  string z_key = test::GetRendezvousKey(z);

  Tensor val(DT_FLOAT, TensorShape({}));
  val.scalar<float>()() = 3.14;
  FixupSourceAndSinkEdges(g);
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false)
      .RunWithRendezvousArgs({{x_key, val}, {y_key, val}}, {z_key}, state);
  state.SetItemsProcessed(static_cast<int64_t>(state.iterations()));
}
BENCHMARK(BM_FeedInputFetchOutput);

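// Replaces `edge` in `g` with a _Send/_Recv node pair keyed by `tensor`,
// sending from `sender` (with `sender_incarnation`) to `receiver`.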
Status ReplaceEdgeWithSendRecv(Graph* g, const Edge* edge, const string& tensor,
                               const string& sender,
                               const uint64 sender_incarnation,
                               const string& receiver) {
  Node* send;
  NodeDef send_def;
  TF_CHECK_OK(NodeDefBuilder(g->NewName("n"), "_Send")
                  .Input(edge->src()->name(), edge->src_output(),
                         edge->src()->output_type(edge->src_output()))
                  .Attr("tensor_name", tensor)
                  .Attr("send_device", sender)
                  .Attr("send_device_incarnation",
                        static_cast<int64_t>(sender_incarnation))
                  .Attr("recv_device", receiver)
                  .Finalize(&send_def));

  TF_ASSIGN_OR_RETURN(send, g->AddNode(send_def));

  Node* recv;
  NodeDef recv_def;
  TF_CHECK_OK(
      NodeDefBuilder(g->NewName("n"), "_Recv")
          .Attr("tensor_name", tensor)
          .Attr("send_device", sender)
          .Attr("send_device_incarnation",
                static_cast<int64_t>(sender_incarnation))
          .Attr("recv_device", receiver)
          .Attr("tensor_type", edge->dst()->input_type(edge->dst_input()))
          .Finalize(&recv_def));

  TF_ASSIGN_OR_RETURN(recv, g->AddNode(recv_def));

  g->AddEdge(edge->src(), edge->src_output(), send, 0);
  g->AddEdge(recv, 0, edge->dst(), edge->dst_input());

  // This control dependency ensures that an Exit op can still be a downstream
  // op of Enter after inserting the Send/Recv pair.
  g->AddControlEdge(edge->src(), recv);

  g->RemoveEdge(edge);
  return OkStatus();
}

// Defines a graph to perform the following computation:
//
//     i = 0
//     while (i <= loop_iters)
//       i += 1;
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
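// The loop carries `loop_vars` variables: the first is incremented each
// iteration and drives the condition; the rest pass through a Relu in the
// body. If `transfer` is true, each LoopCond->Switch edge is replaced with a
// Send/Recv pair to mimic a cross-device training loop.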
static void BM_WhileLoopHelper(::testing::benchmark::State& state,
                               int loop_iters, int loop_vars, bool lower,
                               bool transfer) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;

  // Define the loop body as a function: `x = x + 1`.
  const Tensor one_t = test::AsScalar<int32>(1);

  std::vector<string> args;
  args.reserve(loop_vars);
  args.push_back("x: int32");
  for (int i = 1; i < loop_vars; ++i) {
    args.push_back(strings::StrCat("x", i, ": int32"));
  }

  std::vector<string> body_rets;
  body_rets.reserve(loop_vars);
  body_rets.push_back("y: int32");
  for (int i = 1; i < loop_vars; ++i) {
    body_rets.push_back(strings::StrCat("y", i, ": int32"));
  }

  std::vector<FunctionDefHelper::Node> body_nodes;
  body_nodes.reserve(1 + loop_vars);
  body_nodes.push_back(
      {{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}});
  body_nodes.push_back({{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}});
  for (int i = 1; i < loop_vars; ++i) {
    body_nodes.push_back({{strings::StrCat("y", i)},
                          "Relu",
                          {strings::StrCat("x", i)},
                          {{"T", DT_INT32}}});
  }

  *f_lib_proto.add_function() = FunctionDefHelper::Define(
      // Name
      "XPlusOne",
      // Args
      args,
      // Return values
      body_rets,
      // Attr def
      {},
      // Nodes
      body_nodes);

  // Define the loop condition as a function: `x <= loop_iters`.
  const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
  *f_lib_proto.add_function() = FunctionDefHelper::Define(
      // Name
      "LessThanOrEqualToN",
      // Args
      args,
      // Return values
      {"z: bool"},
      // Attr def
      {},
      // Nodes
      {
          {{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}},
          {{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}},
      });

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Const(root.WithOpName("A"), 0, {});
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs;
  std::vector<DataType> input_types(loop_vars, DT_INT32);
  inputs.reserve(loop_vars);
  for (int i = 0; i < loop_vars; ++i) {
    inputs.push_back(NodeBuilder::NodeOut(a.node()));
  }
  AttrValue int32_attr;
  int32_attr.set_type(DT_INT32);
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XPlusOne");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(inputs)
          .Attr("T", input_types)
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 20)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  auto c = ops::Identity(
      root.WithOpName("C").WithControlDependencies(Output(while_node)),
      Output(while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  if (lower) {
    FunctionLibraryDefinition flib_def(graph->flib_def());
    GraphOptimizationPassOptions opt_options;
    SessionOptions session_options;
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_do_function_inlining(true);
    opt_options.session_options = &session_options;
    opt_options.graph = &graph;
    opt_options.flib_def = &flib_def;
    LowerFunctionalOpsPass pass;
    TF_ASSERT_OK(pass.Run(opt_options));

    if (transfer) {
      // Insert Send/Recv between LoopCond and Switch. This represents a
      // distributed training loop, which is widely used in TF2.
      for (Node* node : graph->nodes()) {
        if (node->type_string() != "LoopCond") {
          continue;
        }

        for (const Edge* edge : node->out_edges()) {
          if (edge->dst()->type_string() != "Switch") {
            continue;
          }
          string tensor_name = strings::StrCat("c", edge->id());
          TF_ASSERT_OK(ReplaceEdgeWithSendRecv(graph.get(), edge, tensor_name,
                                               BOB, 1, ALICE));
        }
      }
    }
  }

  SessionOptions options;
  options.config.set_inter_op_parallelism_threads(4);
  FixupSourceAndSinkEdges(graph.get());
  test::Benchmark("cpu", graph.release(), &options, nullptr, nullptr, "",
                  /*old_benchmark_api=*/false)
      .Run(state);
}

static void BM_LoweredWhileLoop(::testing::benchmark::State& state) {
  const int loop_iters = state.range(0);
  const int loop_vars = state.range(1);

  BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true,
                     /* transfer= */ false);
}
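// Benchmark arguments are (loop_iters, loop_vars).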
BENCHMARK(BM_LoweredWhileLoop)
    ->ArgPair(0, 1)
    ->ArgPair(1, 1)
    ->ArgPair(10, 1)
    ->ArgPair(100, 1)
    ->ArgPair(1000, 1)
    ->ArgPair(0, 100)
    ->ArgPair(1, 100)
    ->ArgPair(10, 100)
    ->ArgPair(100, 100)
    ->ArgPair(1000, 100);

static void BM_LoweredWhileLoopWithTransfer(
    ::testing::benchmark::State& state) {
  const int loop_iters = state.range(0);
  const int loop_vars = state.range(1);

  BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ true,
                     /* transfer= */ true);
}
BENCHMARK(BM_LoweredWhileLoopWithTransfer)
    ->ArgPair(0, 100)
    ->ArgPair(1, 100)
    ->ArgPair(10, 100)
    ->ArgPair(100, 100)
    ->ArgPair(1000, 100)
    ->ArgPair(1, 5000)
    ->ArgPair(10, 5000)
    ->ArgPair(100, 5000)
    ->ArgPair(1000, 5000);

static void BM_FunctionalWhileLoop(::testing::benchmark::State& state) {
  const int loop_iters = state.range(0);
  const int loop_vars = state.range(1);

  BM_WhileLoopHelper(state, loop_iters, loop_vars, /* lower= */ false,
                     /* transfer= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)
    ->ArgPair(0, 1)
    ->ArgPair(1, 1)
    ->ArgPair(10, 1)
    ->ArgPair(100, 1)
    ->ArgPair(1000, 1)
    ->ArgPair(0, 100)
    ->ArgPair(1, 100)
    ->ArgPair(10, 100)
    ->ArgPair(100, 100)
    ->ArgPair(1000, 100);
}  // namespace tensorflow