#pragma once

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <random>
#include <thread>
#include <vector>

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/Parallel.h>
#include <c10/core/GradMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/irange.h>
namespace torch::throughput_benchmark::detail {

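// Runs the benchmark described by `config`: each of config.num_calling_threads
// caller threads performs config.num_warmup_iters warmup calls, then the
// threads collectively execute config.num_iters measured calls to the model,
// and the average per-call latency is reported. This is typically driven from
// Python via the torch.utils.ThroughputBenchmark wrapper rather than being
// called directly.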
template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");

  LOG(INFO) << at::get_parallel_info();

  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each of the threads independently and thus
  // avoid overhead from the benchmark runner itself.
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inputs. "
        "Did you forget to call add_input()?");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);

    for (const auto thread_id : c10::irange(config.num_calling_threads)) {
      // Just in case, we generate (num_iters + num_warmup_iters) inputs for
      // each of the threads. This way, even if one thread ends up doing all
      // the work, we will still have enough inputs.
      for (const auto i [[maybe_unused]] :
           c10::irange(config.num_iters + config.num_warmup_iters)) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }

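  // Shared state for coordinating the caller threads with this (main) thread:
  // `initialized`, `finished` and `start` are protected by the mutex `m`,
  // while `num_attempted_iters` is an atomic counter shared by all callers
  // during the measured phase.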
  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_attempted_iters{0};
  std::vector<std::thread> callers;

  callers.reserve(config.num_calling_threads);

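  // GradMode and the local dispatch key set are thread_local, so they do not
  // automatically propagate to the std::threads spawned below. Capture them
  // here and re-apply them inside each caller so the workers run the model
  // under the same autograd / dispatch state as the launching thread.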
  bool tls_grad_enabled = c10::GradMode::is_enabled();
  c10::impl::LocalDispatchKeySet tls_key_set =
      c10::impl::tls_local_dispatch_key_set();

  for (const auto thread_id : c10::irange(config.num_calling_threads)) {
    callers.emplace_back([&, thread_id]() {
      // We use a condition variable as a barrier to make sure each thread
      // performs the required warmup iterations before we start measuring.
      c10::GradMode::set_enabled(tls_grad_enabled);
      c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);

      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        // NOLINTNEXTLINE(bugprone-infinite-loop)
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
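      // The atomic fetch_add hands each caller a unique "ticket", so across
      // all calling threads exactly config.num_iters measured calls are
      // executed, distributed dynamically between the threads.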
      while (num_attempted_iters.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }

      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
    });
  }

  using Clock = std::chrono::high_resolution_clock;
  using RecordProfile = torch::autograd::profiler::RecordProfile;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;

  std::unique_ptr<RecordProfile> profiler_guard;
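  // Wait for every caller to finish its warmup and reach the barrier, then
  // (optionally) start the autograd profiler, flip `start` and take the start
  // timestamp while still holding the lock; the notify_all below releases all
  // callers into the measured phase at once.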
  {
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    if (!config.profiler_output_path.empty()) {
      LOG(INFO) << "Using Autograd profiler. Trace will be saved to "
                << config.profiler_output_path;
      profiler_guard =
          std::make_unique<RecordProfile>(config.profiler_output_path);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }

  main_worker_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
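  // The end timestamp is taken only after every caller has reported that it
  // finished; resetting the profiler guard stops profiling and writes the
  // trace to config.profiler_output_path (if profiling was enabled).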
  auto end_time = std::chrono::high_resolution_clock::now();
  profiler_guard.reset();
  LOG(INFO) << "Finished benchmark";

  BenchmarkExecutionStats stats;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  // We use config.num_iters instead of num_attempted_iters as it is
  // representative of the real work done. The last attempted iteration on
  // each calling thread doesn't represent real work (i.e. running the model).
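  // Wall-clock time multiplied by the number of calling threads approximates
  // the total thread-time spent in the measured phase; dividing by the number
  // of measured iterations yields the average latency of a single call.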
  stats.latency_avg_ms =
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      total_time_ms * config.num_calling_threads / config.num_iters;
  stats.num_iters = config.num_iters;

  for (auto& t : callers) {
    t.join();
  }
  return stats;
}

} // namespace torch::throughput_benchmark::detail