// xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/throughput_benchmark-inl.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <random>
#include <thread>

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/Parallel.h>
#include <c10/core/GradMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/irange.h>

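// This header implements the measurement loop used by the ThroughputBenchmark
// utility. A minimal usage sketch from Python (assuming the standard
// torch.utils.ThroughputBenchmark binding; `module`, `example_input`, and the
// numbers below are illustrative placeholders, the keyword names mirror the
// BenchmarkConfig fields used in this file):
//
//   bench = torch.utils.ThroughputBenchmark(module)
//   bench.add_input(example_input)
//   stats = bench.benchmark(
//       num_calling_threads=4, num_warmup_iters=10, num_iters=100)
//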
namespace torch::throughput_benchmark::detail {

template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");

  LOG(INFO) << at::get_parallel_info();

  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each thread independently and thus avoid
  // overhead from the benchmark runner itself.
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inputs. "
        "Did you forget to call add_input()?");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);

    for (const auto thread_id : c10::irange(config.num_calling_threads)) {
      // Just in case, we generate num_iters + num_warmup_iters inputs for each
      // thread. This way, even if one thread ends up doing all the work, we
      // will still have enough inputs.
      for (const auto i [[maybe_unused]] :
           c10::irange(config.num_iters + config.num_warmup_iters)) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }

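  // Coordination between this (main) thread and the caller threads:
  //  - each caller bumps `initialized` after its warmup and blocks on
  //    `main_worker_cv` until the main thread sets `start`;
  //  - the shared atomic `num_attempted_iters` caps the total number of
  //    measured iterations across all callers at config.num_iters;
  //  - each caller bumps `finished` when done so the main thread knows when
  //    to stop the wall clock.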
  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_attempted_iters{0};
  std::vector<std::thread> callers;

  callers.reserve(config.num_calling_threads);

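  // Grad mode and the local dispatch key set are thread-local and do not
  // propagate into newly spawned std::threads, so capture the caller's state
  // here and re-apply it at the start of every worker below.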
  bool tls_grad_enabled = c10::GradMode::is_enabled();
  c10::impl::LocalDispatchKeySet tls_key_set =
      c10::impl::tls_local_dispatch_key_set();

  for (const auto thread_id : c10::irange(config.num_calling_threads)) {
    callers.emplace_back([&, thread_id]() {
      // We use a condition variable as a barrier to make sure each thread
      // performs the required warmup iterations before we start measuring.
      c10::GradMode::set_enabled(tls_grad_enabled);
      c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);

      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        // NOLINTNEXTLINE(bugprone-infinite-loop)
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
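      // fetch_add returns the pre-increment value, so exactly config.num_iters
      // calls to runOnce happen across all caller threads in this measured
      // loop; each thread's final, failing attempt does no model work.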
      while (num_attempted_iters.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }

      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
    });
  }

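  // Main thread: wait until every caller has finished its warmup, optionally
  // start the autograd profiler, then release the callers and start the wall
  // clock.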
  using Clock = std::chrono::high_resolution_clock;
  using RecordProfile = torch::autograd::profiler::RecordProfile;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;

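  // Note: RecordProfile is an RAII guard; profiling is active while it is
  // alive and the collected trace is written to profiler_output_path when it
  // is destroyed (see profiler_guard.reset() below).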
  std::unique_ptr<RecordProfile> profiler_guard;
  {
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    if (!config.profiler_output_path.empty()) {
      LOG(INFO) << "Using Autograd profiler. Trace will be saved to "
                << config.profiler_output_path;
      profiler_guard =
          std::make_unique<RecordProfile>(config.profiler_output_path);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }

  main_worker_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  profiler_guard.reset();
  LOG(INFO) << "Finished benchmark";

  BenchmarkExecutionStats stats;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  // We use config.num_iters instead of num_attempted_iters as it is
  // representative of the real work done. The last attempted iteration on each
  // calling thread doesn't represent real work (i.e. running the model).
  stats.latency_avg_ms =
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      total_time_ms * config.num_calling_threads / config.num_iters;
  stats.num_iters = config.num_iters;
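  // Worked example with made-up numbers: 4 calling threads completing
  // num_iters = 100 calls in 1000 ms of wall time give
  // latency_avg_ms = 1000 * 4 / 100 = 40 ms per call.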

  for (auto& t : callers) {
    t.join();
  }
  return stats;
}

} // namespace torch::throughput_benchmark::detail