#include <ATen/Config.h>
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>
#include <ATen/PTThreadPool.h>

#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#else
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <utility>

#ifdef _OPENMP
#include <omp.h>
#endif

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at {
namespace {
// Used with _set_in_parallel_region to mark the master thread as being
// inside a parallel region while parallel primitives are executing.
thread_local bool in_parallel_region_ = false;

// Thread number (task_id) set by the parallel primitive.
thread_local int thread_num_ = 0;

void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

} // namespace (anonymous)

namespace internal {
void set_thread_num(int thread_num) {
  thread_num_ = thread_num;
}
} // namespace internal

namespace {
void _unset_thread_num() {
  thread_num_ = 0;
}

#ifndef C10_MOBILE

const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of intra-op threads set by the user; the value follows one of
// these transitions:
//   NOT_SET -> positive value -> CONSUMED
// or
//   NOT_SET -> CONSUMED
// Meaning:
//  - NOT_SET - pool not initialized, user value not set
//  - positive value - pool not initialized, user value set
//  - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};

int _num_pool_threads(int nthreads) {
  if (nthreads == NOT_SET) {
    nthreads = intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  // minus one because of the master thread
  return nthreads - 1;
}

TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // create a separate thread pool for intra-op
  return *pool;
}
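
// Illustrative state transitions (at::set_num_threads and at::parallel_for
// are the public entry points declared in ATen/Parallel.h):
//
//   at::set_num_threads(8);       // num_intraop_threads: NOT_SET -> 8
//   at::parallel_for(0, n, 1, fn); // first parallel call: exchange(CONSUMED),
//                                  // pool created with 8 - 1 = 7 workers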

#endif // C10_MOBILE

// Run the lambda `fn` over task ids in [0, `range`) using the threadpool.
// `fn` is called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
  for (const auto i : c10::irange(1, range)) {
    _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
  }
  // Run the first task on the current thread directly.
  fn(0, 0);
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");

  pool->run(
      // PThreadPool::run() is blocking. A std::function [const] reference to
      // this lambda cannot go out of scope before PThreadPool::run() returns.
      [&fn](const size_t task_id) {
        fn(0 /* unused */, task_id);
      }, range);
#endif // C10_MOBILE
}
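
// Illustrative dispatch: _run_with_pool(fn, 4) on a non-mobile build queues
// fn(1, 1), fn(2, 2) and fn(3, 3) on the intra-op pool and runs fn(0, 0)
// inline on the calling thread; on mobile, pthreadpool invokes
// fn(0, task_id) for each task_id in [0, 4).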

// RAII guard that supports the in_parallel_region() and get_thread_num() APIs.
struct ParallelRegionGuard {
  ParallelRegionGuard(int task_id) {
    internal::set_thread_num(task_id);
    _set_in_parallel_region(true);
  }

  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }
};

} // namespace

namespace internal {

inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size) {
  if ((end - begin) < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, end - begin));
  }
  // Choose the number of tasks based on grain size and number of threads.
  size_t chunk_size = divup((end - begin), get_num_threads());
  // Make sure each task is at least grain_size in size.
  chunk_size = std::max((size_t)grain_size, chunk_size);
  size_t num_tasks = divup((end - begin), chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}
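
// Worked example (illustrative): begin = 0, end = 10000, grain_size = 1000
// and get_num_threads() == 8 give
//   chunk_size = max(1000, divup(10000, 8)) = 1250
//   num_tasks  = divup(10000, 1250) = 8
// With grain_size = 3000 instead, chunk_size = 3000 and num_tasks = 4.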

void invoke_parallel(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks = 0, chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);

  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    std::mutex mutex;
    std::atomic_size_t remaining{0};
    std::condition_variable cv;
  } state;

  auto task = [f, &state, begin, end, chunk_size]
      (int /* unused */, size_t task_id) {
    int64_t local_start = begin + task_id * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      try {
        ParallelRegionGuard guard(task_id);
        f(local_start, local_end);
      } catch (...) {
        // Remember only the first exception; it is rethrown on the calling
        // thread once all tasks have finished.
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(std::move(task), num_tasks);

  // Wait for all tasks to finish.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    // Use a predicate to guard against spurious wakeups.
    state.cv.wait(lk, [&]() { return state.remaining == 0; });
  }
  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}
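
// Usage sketch (illustrative): at::parallel_for from ATen/Parallel.h
// dispatches to invoke_parallel when the range is large enough, more than
// one thread is available and the caller is not already inside a parallel
// region:
//
//   at::parallel_for(0, n, /* grain_size */ 2048, [&](int64_t begin, int64_t end) {
//     for (const auto i : c10::irange(begin, end)) {
//       out[i] = a[i] + b[i];
//     }
//   });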

} // namespace internal

void init_num_threads() {
#ifdef _OPENMP
  // The native backend manages intra-op threads itself; restrict OpenMP to a
  // single thread to avoid nested parallelism and oversubscription.
  omp_set_num_threads(1);
#endif

#if AT_MKL_ENABLED()
  // Likewise keep MKL single-threaded inside native intra-op regions.
  mkl_set_num_threads(1);
#endif

#ifdef C10_MOBILE
  caffe2::pthreadpool();
#endif
}
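
// Usage sketch (illustrative): code that calls ATen parallel primitives from
// its own threads is expected to call at::init_num_threads() once per thread:
//
//   std::thread worker([] {
//     at::init_num_threads();
//     // ... run ATen ops / at::parallel_for here ...
//   });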

void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int no_value = NOT_SET;
  if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
    // num_intraop_threads either stores a positive integer or CONSUMED;
    // check that the requested size matches the current one.
    int stored_nthreads = num_intraop_threads.load();
    if (stored_nthreads <= 0) {
      // plus one because of the master thread
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      stored_nthreads = _get_intraop_pool().size() + 1;
    }
    if (stored_nthreads != nthreads) {
      TORCH_WARN(
        "Cannot set number of intraop threads "
        "after parallel work has started or after set_num_threads call "
        "when using native parallel backend");
    }
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}
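
// Usage sketch (illustrative): on the non-mobile build the thread count can
// only be chosen before the intra-op pool exists, i.e. before the first
// parallel region:
//
//   at::set_num_threads(4);        // ok: pool not created yet
//   at::parallel_for(0, n, 1, fn); // pool created with 3 workers + master
//   at::set_num_threads(8);        // warns: pool size is already fixed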

int get_num_threads() {
  at::internal::lazy_init_num_threads();
#ifndef C10_MOBILE
  // Avoid initializing the pool unnecessarily: the pool cannot be resized
  // after initialization.
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return _get_intraop_pool().size() + 1;
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}

int get_thread_num() {
  return thread_num_;
}

bool in_parallel_region() {
#ifndef C10_MOBILE
  return in_parallel_region_ || (
      num_intraop_threads.load() == CONSUMED &&
      // Needed as intraop_launch() doesn't set in_parallel_region().
      _get_intraop_pool().inThreadPool()
  );
#else
  return in_parallel_region_;
#endif // C10_MOBILE
}
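
// Example (illustrative): inside a parallel_for callback, get_thread_num()
// returns the task id set by ParallelRegionGuard and in_parallel_region()
// is true; on the calling thread outside of any parallel region they return
// 0 and false respectively:
//
//   at::parallel_for(0, n, 1, [](int64_t, int64_t) {
//     int tid = at::get_thread_num();          // task id of this chunk
//     bool inside = at::in_parallel_region();  // true
//     (void)tid; (void)inside;
//   });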

void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(std::move(func));
  } else {
    // Execute inline when already in a parallel region or single-threaded.
    func();
  }
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  func();
#endif // C10_MOBILE
}
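
// Usage sketch (illustrative): intraop_launch schedules independent work on
// the intra-op pool when possible and otherwise runs it inline:
//
//   at::intraop_launch([] {
//     // runs on an intra-op worker thread, or inline when already inside a
//     // parallel region or when only one thread is available
//   });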

c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
#ifndef C10_MOBILE
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(
        [func, future]() {
          func();
          future->markCompleted();
        }
    );
  } else {
    func();
    future->markCompleted();
  }
  return future;
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
  func();
  future->markCompleted();
  return future;
#endif // C10_MOBILE
}
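
// Usage sketch (illustrative): the returned future can be used to join the
// asynchronously launched task:
//
//   auto fut = at::intraop_launch_future([] { /* work */ });
//   fut->wait();  // returns once markCompleted() has been called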

} // namespace at
#endif // AT_PARALLEL_NATIVE