/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <atomic>

#include "tensorflow/core/common_runtime/device/device_event_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Subclass EventMgr to access its private constructor.
class TEST_EventMgr : public EventMgr {
 public:
  TEST_EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
      : EventMgr(se, gpu_options) {}
};

class TEST_EventMgrHelper {
 public:
  explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {
    // The polling loop can interfere with the measurements made here, and
    // isn't needed since the member PollEvents() always clears the queue.
    // The tested behavior is slightly different from what may occur in
    // ordinary execution.
    StopPollingLoop();
  }

  size_t queue_size() {
    mutex_lock l(em_->mu_);
    return em_->used_events_.size();
  }

  size_t free_size() {
    mutex_lock l(em_->mu_);
    return em_->free_events_.size();
  }

  void PollEvents() {
    while (queue_size() > 0) {
      // For ordinary tensor frees, this function
      // should synchronously harvest all complete
      // events and execute the corresponding memory frees.
      EventMgr::ToFreeVector to_free;
      {
        mutex_lock l(em_->mu_);
        em_->PollEvents(true, &to_free);
      }
      em_->FreeMemory(to_free);
    }
  }

  void StopPollingLoop() { return em_->StopPollingLoop(); }

  void StartPollingLoop() { return em_->StartPollingLoop(); }

 private:
  EventMgr* em_;
};
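
// Typical use in the tests below (a minimal sketch): construct a
// TEST_EventMgr for a StreamExecutor, wrap it in a TEST_EventMgrHelper, then
// inspect queue_size()/free_size() and call PollEvents() to drain the
// used-event queue synchronously, e.g.:
//
//   TEST_EventMgr em(stream_exec, GPUOptions());
//   TEST_EventMgrHelper th(&em);
//   EXPECT_EQ(0, th.queue_size());
//   th.PollEvents();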

static std::atomic_int_fast64_t live_tensor_bytes(0);

// A TensorBuffer that counts live memory usage for testing
class TestTensorBuffer : public TensorBuffer {
 public:
  explicit TestTensorBuffer(size_t bytes)
      : TensorBuffer(nullptr), bytes_(bytes) {
    live_tensor_bytes += bytes_;
  }
  ~TestTensorBuffer() override { live_tensor_bytes -= bytes_; }

  size_t size() const override { return bytes_; }

  // Not used in this test
  TensorBuffer* root_buffer() override { return nullptr; }
  void FillAllocationDescription(AllocationDescription* arg) const override {}

 private:
  size_t bytes_;
};

namespace {

TEST(EventMgr, Empty) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  EXPECT_EQ(0, th.queue_size());
  EXPECT_EQ(0, th.free_size());
}

// Tests that WarnIfInCallback() triggers correctly.
TEST(EventMgr, WarnIfInCallback) {
  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  TEST_EventMgr em(stream_exec, GPUOptions());
  TEST_EventMgrHelper th(&em);
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  bool hit = false;
  th.StartPollingLoop();
  device_event_mgr::WarnIfInCallback([&hit] { hit = true; });
  EXPECT_FALSE(hit);
  Notification note;
  em.ThenExecute(stream.get(), [&hit, &note]() {
    device_event_mgr::WarnIfInCallback([&hit, &note] {
      hit = true;
      note.Notify();
    });
  });
  note.WaitForNotification();
  EXPECT_TRUE(hit);
}
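
// Note: WarnIfInCallback() is expected to report "in callback" only when it
// runs on one of the EventMgr's callback threads (presumably tracked via a
// thread-local marker on those threads), which is why the first call above
// does not set `hit` while the one issued from inside ThenExecute() does.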
}  // namespace

// Provides access to private resources of BaseGPUDevice.
class GPUDeviceTestHelper {
 public:
  GPUDeviceTestHelper(size_t memory_limit, int pending_cap) {
    SessionOptions sops;
    device_ =
        DeviceFactory::NewDevice(DEVICE_GPU, sops, "/job:a/replica:0/task:0");
    gpu_.reset(reinterpret_cast<BaseGPUDevice*>(device_.release()));
    gpu_allocator_ = GPUProcessState::singleton()->GetGPUAllocator(
        GPUOptions(), TfDeviceId(0), memory_limit, /*peer_gpu_ids=*/{});
    host_allocator_ = GPUProcessState::singleton()->GetGpuHostAllocator(0);
  }

  BaseGPUDevice* gpu() { return gpu_.get(); }
  Allocator* gpu_allocator() { return gpu_allocator_; }
  Allocator* host_allocator() { return host_allocator_; }
  se::Stream* compute_stream() { return gpu_->stream_->compute; }
  se::Stream* h2d_stream() { return gpu_->stream_->host_to_device; }
  se::Stream* d2h_stream() { return gpu_->stream_->device_to_host; }
  se::Stream* d2d_stream() { return gpu_->stream_->device_to_device[0]; }
  EventMgr* event_mgr() { return gpu_->em_; }
  int pending_cap() { return gpu_->pending_cap_; }

 private:
  std::unique_ptr<Device> device_;
  std::unique_ptr<BaseGPUDevice> gpu_;
  Allocator* gpu_allocator_;
  Allocator* host_allocator_;
};
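
// A minimal usage sketch (mirroring BM_chain_ops below, with `helper` as an
// illustrative local name): construct the helper with a memory limit and
// pending cap, then pull out the streams, allocators, and EventMgr it
// exposes, e.g.:
//
//   GPUDeviceTestHelper helper(1 << 24, /*pending_cap=*/0);
//   se::Stream* compute = helper.compute_stream();
//   EventMgr* em = helper.event_mgr();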

namespace {

// Class that can queue some GPU data transfers and simple kernels.
class EMBenchmarkHelper {
  GPUDeviceTestHelper* gpu_helper_;
  // We need one of these for each Add op in the chain.
  std::vector<std::unique_ptr<OpKernel>> add_kernels_;
  std::vector<OpKernelContext::Params*> add_params_;
  std::vector<std::unique_ptr<OpKernelContext>> add_contexts_;
  // The rest of these are one per chain.
  NodeDef add_node_def_;
  NodeDef id_node_def_;
  gtl::InlinedVector<TensorValue, 4> add_inputs_;
  std::vector<AllocatorAttributes> allocator_attrs_;
  gtl::InlinedVector<Tensor, 4> gpu_inputs_;
  gtl::InlinedVector<Tensor, 4> gpu_outputs_;
  gtl::InlinedVector<Tensor, 4> host_inputs_;
  gtl::InlinedVector<Tensor, 4> host_outputs_;

 public:
  // Length of tensors.  TODO(tucker): make this a variable parameter.
  static constexpr int kTDim = 1024;

  int num_ops() const { return add_kernels_.size(); }
  size_t tensor_size() const {
    return add_inputs_.empty() ? 0 : add_inputs_[0]->NumElements();
  }

  Tensor& host_outputs(int i) { return host_outputs_[i]; }
  Tensor& host_inputs(int i) { return host_inputs_[i]; }

  EMBenchmarkHelper(GPUDeviceTestHelper* h) : gpu_helper_(h) {}

  void ReInit(int num_ops, int tensor_size) {
    gpu_inputs_.clear();
    while (gpu_inputs_.size() < 2) {
      gpu_inputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                   {tensor_size}, AllocationAttributes()));
    }
    gpu_outputs_.clear();
    while (gpu_outputs_.size() < 1) {
      gpu_outputs_.push_back(Tensor(gpu_helper_->gpu_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
    }
    host_inputs_.clear();
    while (host_inputs_.size() < 2) {
      int instance_index = host_inputs_.size();
      host_inputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                    {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_inputs_.back().flat<float>()(i) =
            i * (1.0 + (0.5 * instance_index));
      }
    }
    host_outputs_.clear();
    while (host_outputs_.size() < 1) {
      host_outputs_.push_back(Tensor(gpu_helper_->host_allocator(), DT_FLOAT,
                                     {tensor_size}, AllocationAttributes()));
      for (int i = 0; i < tensor_size; ++i) {
        host_outputs_.back().flat<float>()(i) = -1;
      }
    }
    add_kernels_.clear();
    add_params_.clear();
    while (add_kernels_.size() < num_ops) {
      MakeAddOp();
    }
  }

  std::unique_ptr<OpKernel> GetOpKernel(const NodeDef& node_def,
                                        Status* status) {
    return CreateOpKernel("GPU", gpu_helper_->gpu(),
                          gpu_helper_->gpu_allocator(), node_def,
                          TF_GRAPH_DEF_VERSION, status);
  }

  void MakeAddOp() {
    if (add_kernels_.empty()) {
      TF_ASSERT_OK(NodeDefBuilder("add_op", "Add")
                       .Input(FakeInput(DT_FLOAT))
                       .Input(FakeInput(DT_FLOAT))
                       .Device("/job:a/replica:0/task:0/GPU:0")
                       .Finalize(&add_node_def_));
    }
    Status status;
    add_kernels_.emplace_back(GetOpKernel(add_node_def_, &status));
    TF_ASSERT_OK(status);
    add_params_.push_back(new OpKernelContext::Params);
    PrepOpKernel(add_params_.back(), add_kernels_.back().get());
  }

  void SetOutputAttrs(OpKernelContext::Params* params,
                      std::vector<AllocatorAttributes>* attrs) {
    attrs->clear();
    for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
      AllocatorAttributes attr;
      const bool on_host =
          (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
      attr.set_on_host(on_host);
      attrs->push_back(attr);
    }
    params->output_attr_array = attrs->data();
    params->forward_from_array = {};
  }

  void PrepOpKernel(OpKernelContext::Params* params, OpKernel* kernel) {
    // This mimics what happens in ExecutorState::Process to run
    // a single graph node.
    params->step_id = 1;
    params->device = gpu_helper_->gpu();
    params->log_memory = false;
    params->rendezvous = nullptr;
    params->collective_executor = nullptr;
    params->session_state = nullptr;  // ???
    params->session_handle = "session_handle";
    params->tensor_store = nullptr;
    params->cancellation_manager = nullptr;

    params->call_frame = nullptr;
    params->function_library = nullptr;
    params->runner = nullptr;
    params->graph_collector = nullptr;

    params->step_container = nullptr;
    params->slice_reader_cache = nullptr;
    params->resource_manager = gpu_helper_->gpu()->resource_manager();

    params->stats_collector = nullptr;
    params->inc_num_deferred_ops_function = nullptr;
    params->dec_num_deferred_ops_function = nullptr;

    params->op_device_context = nullptr;
    params->track_allocations = false;
    params->op_kernel = kernel;
    params->frame_iter = FrameAndIter(0, 0);
    params->is_input_dead = false;

    if (add_inputs_.empty()) {
      add_inputs_.resize(2);
      add_inputs_[0] = TensorValue(&gpu_inputs_[0]);
      add_inputs_[1] = TensorValue(&gpu_inputs_[1]);
    }
    params->inputs = add_inputs_;
    SetOutputAttrs(params, &allocator_attrs_);
  }

  struct TimeSet {
    int iter = 0;
    int64_t start = 0;
    int64_t copy_done = 0;
    int64_t compute_done = 0;
    int64_t final_copy = 0;
    int64_t all_done = 0;
  };
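
  // TimeSet fields are recorded by DoAddChain() as absolute microsecond
  // timestamps (Env::Default()->NowMicros()) at successive stages of one
  // iteration; DisplayTimes() then converts them in place into per-stage
  // durations before printing.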

  // Display sampled iteration times giving the approximate breakdown
  // within iterations and overall curve.
  void DisplayTimes(std::vector<TimeSet>* times) {
    LOG(INFO) << "Summarize set of " << times->size() << " iters";
    for (auto& ts : *times) {
      ts.final_copy = ts.all_done - ts.compute_done;
      ts.compute_done = ts.compute_done - ts.copy_done;
      ts.copy_done = ts.copy_done - ts.start;
      ts.all_done = ts.all_done - ts.start;
    }
    struct TSSort {
      bool operator()(const TimeSet& a, const TimeSet& b) {
        return a.all_done < b.all_done;
      }
    };
    std::sort(times->begin(), times->end(), TSSort());
    int64_t last_time = 0;
    // Display first, last and every > 5% change.
    for (int i = 0; i < times->size(); ++i) {
      if (i == (times->size() - 1) ||
          (times->at(i).all_done >= (1.05 * last_time))) {
        LOG(INFO) << "rank " << i << " iter: " << times->at(i).iter
                  << " copy: " << times->at(i).copy_done
                  << " compute: " << times->at(i).compute_done
                  << " copy back: " << times->at(i).final_copy
                  << " sum: " << times->at(i).all_done;
        last_time = times->at(i).all_done;
      }
    }
  }

  // Queue one work unit on the GPU as follows:
  // 1. Copy 2 input tensors from CPU to GPU using h2d stream.
  // 2. Instruct compute stream to wait on h2d stream.
  // 3. Queue a sequence of Add ops on the compute stream, all using
  //    the same input tensors, allocating their own output tensors.
  // 4. Instruct d2h stream to wait on the compute stream.
  // 5. Copy final output tensor back to the CPU.
  // 6. Instruct the EventMgr to execute callback when the final tensor
  //    copy completes.
  // If event_after_add == true then additionally instruct the EventMgr
  //    to execute the callback after each Add completes.
  // The optional times parameter is used for gathering detailed timing
  // data.
  void DoAddChain(int adds_per_copy, int rounds, bool event_after_add,
                  std::function<void()> callback, std::vector<TimeSet>* times) {
    // Take an extra ref on the inputs so that the add doesn't compute in place.
    Tensor alias0(gpu_inputs_[0]);
    Tensor alias1(gpu_inputs_[1]);
    for (int r = 0; r < rounds; ++r) {
      if (times) {
        times->at(r).iter = r;
        times->at(r).start = Env::Default()->NowMicros();
      }
      gpu_helper_->h2d_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      // Begin by copying the input values from CPU to GPU.
      const int64_t src_bytes = host_inputs_[0].TotalBytes();
      se::DeviceMemoryBase gpu_dst_ptr0(DMAHelper::base(&gpu_inputs_[0]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr0, DMAHelper::base(&host_inputs_[0]), src_bytes);
      se::DeviceMemoryBase gpu_dst_ptr1(DMAHelper::base(&gpu_inputs_[1]),
                                        src_bytes);
      gpu_helper_->h2d_stream()->ThenMemcpy(
          &gpu_dst_ptr1, DMAHelper::base(&host_inputs_[1]), src_bytes);
      gpu_helper_->compute_stream()->ThenWaitFor(gpu_helper_->h2d_stream());
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).copy_done = Env::Default()->NowMicros();
            });
      }
      std::unique_ptr<OpKernelContext> ctx;
      for (int apc = 0; apc < adds_per_copy; ++apc) {
        ctx.reset(new OpKernelContext(add_params_[apc], 1));
        gpu_helper_->gpu()->Compute(add_kernels_[apc].get(), ctx.get());
        TF_ASSERT_OK(ctx->status());
        if (event_after_add) {
          gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->compute_stream(),
                                                callback);
        }
      }
      // Finish by copying output back to CPU.
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->compute_stream(), [times, r]() {
              times->at(r).compute_done = Env::Default()->NowMicros();
            });
      }
      gpu_helper_->d2h_stream()->ThenWaitFor(gpu_helper_->compute_stream());
      const int64_t return_bytes = ctx->mutable_output(0)->TotalBytes();
      se::DeviceMemoryBase gpu_src_ptr(DMAHelper::base(ctx->mutable_output(0)),
                                       return_bytes);
      gpu_helper_->d2h_stream()->ThenMemcpy(DMAHelper::base(&host_outputs_[0]),
                                            gpu_src_ptr, return_bytes);
      gpu_helper_->event_mgr()->ThenExecute(gpu_helper_->d2h_stream(),
                                            callback);
      if (times) {
        gpu_helper_->event_mgr()->ThenExecute(
            gpu_helper_->d2h_stream(), [times, r]() {
              times->at(r).all_done = Env::Default()->NowMicros();
            });
      }
    }
  }
};
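
// A minimal sketch of how DoAddChain() is driven below (where `helper`
// stands for an EMBenchmarkHelper instance, as bm_helper does further down):
// queue one round of the chain with a counting callback, then spin until
// every expected callback has fired.
//
//   std::atomic<int> counter(0);
//   auto callback = [&counter]() { counter.fetch_add(1); };
//   helper.DoAddChain(/*adds_per_copy=*/10, /*rounds=*/1,
//                     /*event_after_add=*/true, callback, nullptr);
//   // 1 callback after the final copy + 1 per Add.
//   while (counter < 1 + 10) Env::Default()->SleepForMicroseconds(1);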

static void BM_no_ops(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  const int iters = state.max_iterations;

  auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
  std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
  CHECK(stream);
  stream->Init();
  TEST_EventMgr em(stream_exec, GPUOptions());

  auto benchmark_exec = [&]() {
    std::atomic<int> counter;
    counter.store(0, std::memory_order_seq_cst);
    se::Stream* stream_ptr = stream.get();
    auto runner = [&em, &counter, stream_ptr, iters]() {
      auto callback = [&counter]() { counter.fetch_add(1); };
      for (int i = 0; i < iters; ++i) {
        em.ThenExecute(stream_ptr, callback);
      }
    };
    for (int t = 0; t < threads; ++t) {
      Env::Default()->SchedClosure(runner);
    }
    int expected = iters * threads;
    while (counter < expected) {
      Env::Default()->SleepForMicroseconds(1);
    }
  };

#ifdef PLATFORM_GOOGLE

  // The timer starts automatically.
  while (state.KeepRunningBatch(state.max_iterations)) {
    benchmark_exec();
  }
#else
  // TensorFlow's own benchmark implementation does not support
  // KeepRunningBatch (yet), so we use Resume/PauseTiming instead.
  // FIXME: Remove this ifdef once all of TensorFlow's benchmarks have
  // switched to the OSS benchmark library.

  state.ResumeTiming();
  benchmark_exec();
  state.PauseTiming();
#endif
}
BENCHMARK(BM_no_ops)->UseRealTime()->Arg(4)->Arg(8)->Arg(32);
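
// BM_no_ops measures only the EventMgr callback path: `threads` closures each
// enqueue `iters` empty callbacks via ThenExecute() on an otherwise idle
// stream, and the benchmark spins until all of them have run.  No GPU kernels
// are launched.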

// Benchmark functions are defined at top level.  In order to provide a real,
// persistent GPUDevice to the following function, the helper objects also
// need to be at top level.  But then we can't clean them up without a CUDA
// runtime error, so we just leak them.
GPUDeviceTestHelper* gpu_helper = nullptr;
EMBenchmarkHelper* bm_helper = nullptr;
mutex helper_mu;

#ifdef PLATFORM_GOOGLE
static void BM_chain_ops(::testing::benchmark::State& state, int tensor_size,
                         int adds_per_round, bool event_after_add,
                         int pending_cap) {
#else
static void BM_chain_ops(::testing::benchmark::State& state, int tensor_size,
                         int adds_per_round, bool event_after_add,
                         int pending_cap, int threads) {
#endif
  const int iters = state.max_iterations;
  {
    mutex_lock l(helper_mu);
    if (gpu_helper && gpu_helper->pending_cap() != pending_cap) {
      delete bm_helper;
      bm_helper = nullptr;
      delete gpu_helper;
      gpu_helper = nullptr;
    }
    if (!gpu_helper) {
      gpu_helper = new GPUDeviceTestHelper(1 << 24, pending_cap);
      bm_helper = new EMBenchmarkHelper(gpu_helper);
    }
    if (bm_helper->num_ops() != adds_per_round ||
        bm_helper->tensor_size() != tensor_size) {
      bm_helper->ReInit(adds_per_round, tensor_size);
    }
  }
  std::vector<EMBenchmarkHelper::TimeSet> times;
  std::vector<EMBenchmarkHelper::TimeSet>* time_ptr = nullptr;
  if (VLOG_IS_ON(1)) {
    times.resize(iters);
    time_ptr = &times;
  }
  std::atomic<int> counter;
  counter.store(0, std::memory_order_seq_cst);
  auto callback = [&counter]() { counter.fetch_add(1); };
  // First iter is always slow, so do one prior to the timed loop.
  int expected = 1 + (event_after_add ? adds_per_round : 0);
  bm_helper->DoAddChain(adds_per_round, 1, event_after_add, callback, nullptr);
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  counter = 0;

#ifdef PLATFORM_GOOGLE
  while (state.KeepRunningBatch(state.max_iterations)) {
    expected = iters * (1 + (event_after_add ? adds_per_round : 0));
    bm_helper->DoAddChain(adds_per_round, iters, event_after_add, callback,
                          time_ptr);
    while (counter < expected) {
      Env::Default()->SleepForMicroseconds(1);
    }
  }
#else
  state.ResumeTiming();
  expected = threads * iters * (1 + (event_after_add ? adds_per_round : 0));
  for (int i = 0; i < threads; ++i) {
    Env::Default()->SchedClosure(
        [callback, iters, adds_per_round, event_after_add, time_ptr]() {
          bm_helper->DoAddChain(adds_per_round, iters, event_after_add,
                                callback, time_ptr);
        });
  }
  while (counter < expected) {
    Env::Default()->SleepForMicroseconds(1);
  }
  state.PauseTiming();
#endif
  VLOG(1) << "counter = " << counter << " post_execute Output: "
          << bm_helper->host_outputs(0).SummarizeValue(64);
  if (time_ptr) bm_helper->DisplayTimes(time_ptr);
}
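
// The wrappers below follow the naming convention
// BM_chain_<tensor elements>_<adds per round>_<event_after_add>.  Under
// PLATFORM_GOOGLE the thread count comes from BENCHMARK(...)->Threads(n);
// otherwise it is passed as the benchmark Arg() and read via state.range(0).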

#ifdef PLATFORM_GOOGLE
static void BM_chain_1024_1_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 1, false, 0);
}

static void BM_chain_1024_1_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 1, true, 0);
}

static void BM_chain_1024_10_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 10, false, 0);
}

static void BM_chain_1024_10_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 10, true, 0);
}

static void BM_chain_1024_100_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 100, false, 0);
}

static void BM_chain_1024_100_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1024, 100, true, 0);
}

static void BM_chain_1M_1_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 1, false, 0);
}

static void BM_chain_1M_1_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 1, true, 0);
}

static void BM_chain_1M_10_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 10, false, 0);
}

static void BM_chain_1M_10_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 10, true, 0);
}

static void BM_chain_1M_100_false(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 100, false, 0);
}

static void BM_chain_1M_100_true(::testing::benchmark::State& state) {
  BM_chain_ops(state, 1 << 20, 100, true, 0);
}

BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Threads(8);

BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(1);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(2);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Threads(8);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Threads(8);
#else
static void BM_chain_1024_1_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 1, false, 0, threads);
}

static void BM_chain_1024_1_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 1, true, 0, threads);
}

static void BM_chain_1024_10_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 10, false, 0, threads);
}

static void BM_chain_1024_10_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 10, true, 0, threads);
}

static void BM_chain_1024_100_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 100, false, 0, threads);
}

static void BM_chain_1024_100_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1024, 100, true, 0, threads);
}

static void BM_chain_1M_1_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 1, false, 0, threads);
}

static void BM_chain_1M_1_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 1, true, 0, threads);
}

static void BM_chain_1M_10_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 10, false, 0, threads);
}

static void BM_chain_1M_10_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 10, true, 0, threads);
}

static void BM_chain_1M_100_false(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 100, false, 0, threads);
}

static void BM_chain_1M_100_true(::testing::benchmark::State& state) {
  const int threads = state.range(0);
  BM_chain_ops(state, 1 << 20, 100, true, 0, threads);
}

BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_1_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_1_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_10_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_10_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1024_100_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1024_100_true)->UseRealTime()->Arg(8);

BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_1_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_1_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_10_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_10_true)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(1);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(2);
BENCHMARK(BM_chain_1M_100_false)->UseRealTime()->Arg(8);
BENCHMARK(BM_chain_1M_100_true)->UseRealTime()->Arg(8);
#endif
}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM