xref: /aosp_15_r20/external/pytorch/binaries/at_launch_benchmark.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include "ATen/Parallel.h"
2 
3 #include "c10/util/Flags.h"
4 #include "caffe2/core/init.h"
5 
6 #include <atomic>
7 #include <chrono>
8 #include <condition_variable>
9 #include <iostream>
10 #include <mutex>
11 #include <ctime>
12 
13 C10_DEFINE_int(iter, 10e4, "Number of at::launch iterations (tasks)");
14 C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
15 C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
16 C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark")
17 
namespace {
// State shared between the launching thread and the launched tasks.

// Guards the wait in launch_tasks_and_wait() and the notification in
// launch_tasks().
std::mutex mutex;
// Signalled once the task counter reaches the target of the current round.
std::condition_variable cv;
// Number of innermost tasks that have run in the current round.
std::atomic<int> counter(0);
// Target task count of the current round; written by launch_tasks_and_wait().
int iter{0};
} // namespace
24 
launch_tasks()25  void launch_tasks() {
26   at::launch([]() {
27     at::launch([](){
28       at::launch([]() {
29         auto cur_ctr = ++counter;
30         if (cur_ctr == iter) {
31           std::unique_lock<std::mutex> lk(mutex);
32           cv.notify_one();
33         }
34       });
35     });
36   });
37 }
38 
launch_tasks_and_wait(int tasks_num)39 void launch_tasks_and_wait(int tasks_num) {
40   iter = tasks_num;
41   counter = 0;
42   for (auto idx = 0; idx < iter; ++idx) {
43     launch_tasks();
44   }
45   {
46     std::unique_lock<std::mutex> lk(mutex);
47     while (counter < iter) {
48       cv.wait(lk);
49     }
50   }
51 }
52 
main(int argc,char ** argv)53 int main(int argc, char** argv) {
54   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
55     std::cout << "Failed to parse command line flags" << std::endl;
56     return -1;
57   }
58   caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
59   at::init_num_threads();
60 
61   if (FLAGS_inter_op_threads > 0) {
62     at::set_num_interop_threads(FLAGS_inter_op_threads);
63   }
64 
65   typedef std::chrono::high_resolution_clock clock;
66   typedef std::chrono::milliseconds ms;
67 
68   std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks using "
69             << at::get_num_interop_threads() << " threads "
70             << std::endl;
71 
72   std::chrono::time_point<clock> start_time = clock::now();
73   launch_tasks_and_wait(FLAGS_warmup_iter);
74   auto duration = static_cast<float>(
75       std::chrono::duration_cast<ms>(clock::now() - start_time).count());
76 
77   std::cout << "Warmup time: " << duration << " ms." << std::endl;
78 
79   std::cout << "Launching " << FLAGS_iter << " tasks using "
80             << at::get_num_interop_threads() << " threads "
81             << std::endl;
82 
83   for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
84     start_time = clock::now();
85     launch_tasks_and_wait(FLAGS_iter);
86     duration = static_cast<float>(
87         std::chrono::duration_cast<ms>(clock::now() - start_time).count());
88 
89     std::cout << "Time to run " << iter << " iterations "
90               << (duration/1000.0) << " s." << std::endl;
91   }
92 
93   return 0;
94 }
95