1 #include "ATen/Parallel.h"
2
3 #include "c10/util/Flags.h"
4 #include "caffe2/core/init.h"
5
6 #include <atomic>
7 #include <chrono>
8 #include <condition_variable>
9 #include <iostream>
10 #include <mutex>
11 #include <ctime>
12
13 C10_DEFINE_int(iter, 10e4, "Number of at::launch iterations (tasks)");
14 C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
15 C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
16 C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark")
17
namespace {
// Number of innermost tasks the current round expects to complete.
// Written by the main thread in launch_tasks_and_wait() and read by inter-op
// worker tasks in launch_tasks(). It must be atomic: a worker that has already
// incremented `counter` may still be reading `iter` when the next round
// overwrites it, which would be a data race on a plain int.
std::atomic<int> iter{0};
// Number of innermost tasks that have finished in the current round.
std::atomic<int> counter{0};
// Notified by the task whose increment makes `counter` equal to `iter`.
std::condition_variable cv;
// Held by the waiter while it checks `counter`, and by the notifier around
// cv.notify_one(), so the wakeup cannot be lost.
std::mutex mutex;
} // namespace
24
launch_tasks()25 void launch_tasks() {
26 at::launch([]() {
27 at::launch([](){
28 at::launch([]() {
29 auto cur_ctr = ++counter;
30 if (cur_ctr == iter) {
31 std::unique_lock<std::mutex> lk(mutex);
32 cv.notify_one();
33 }
34 });
35 });
36 });
37 }
38
launch_tasks_and_wait(int tasks_num)39 void launch_tasks_and_wait(int tasks_num) {
40 iter = tasks_num;
41 counter = 0;
42 for (auto idx = 0; idx < iter; ++idx) {
43 launch_tasks();
44 }
45 {
46 std::unique_lock<std::mutex> lk(mutex);
47 while (counter < iter) {
48 cv.wait(lk);
49 }
50 }
51 }
52
main(int argc,char ** argv)53 int main(int argc, char** argv) {
54 if (!c10::ParseCommandLineFlags(&argc, &argv)) {
55 std::cout << "Failed to parse command line flags" << std::endl;
56 return -1;
57 }
58 caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
59 at::init_num_threads();
60
61 if (FLAGS_inter_op_threads > 0) {
62 at::set_num_interop_threads(FLAGS_inter_op_threads);
63 }
64
65 typedef std::chrono::high_resolution_clock clock;
66 typedef std::chrono::milliseconds ms;
67
68 std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks using "
69 << at::get_num_interop_threads() << " threads "
70 << std::endl;
71
72 std::chrono::time_point<clock> start_time = clock::now();
73 launch_tasks_and_wait(FLAGS_warmup_iter);
74 auto duration = static_cast<float>(
75 std::chrono::duration_cast<ms>(clock::now() - start_time).count());
76
77 std::cout << "Warmup time: " << duration << " ms." << std::endl;
78
79 std::cout << "Launching " << FLAGS_iter << " tasks using "
80 << at::get_num_interop_threads() << " threads "
81 << std::endl;
82
83 for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
84 start_time = clock::now();
85 launch_tasks_and_wait(FLAGS_iter);
86 duration = static_cast<float>(
87 std::chrono::duration_cast<ms>(clock::now() - start_time).count());
88
89 std::cout << "Time to run " << iter << " iterations "
90 << (duration/1000.0) << " s." << std::endl;
91 }
92
93 return 0;
94 }
95