// aten/src/ATen/ParallelNative.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/Config.h>
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>
#include <ATen/PTThreadPool.h>

#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#else
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE

#include <atomic>
#include <utility>

#ifdef _OPENMP
#include <omp.h>
#endif

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at {
namespace {
// Used with _set_in_parallel_region to mark the master thread as being in a
// parallel region while it executes parallel primitives.
thread_local bool in_parallel_region_ = false;

// Thread number (task_id) set by the parallel primitive.
thread_local int thread_num_ = 0;

void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

}  // namespace (anonymous)

namespace internal {
void set_thread_num(int thread_num) {
  thread_num_ = thread_num;
}
} // namespace internal

namespace {
void _unset_thread_num() {
  thread_num_ = 0;
}

#ifndef C10_MOBILE

const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of threads set by the user. Valid transitions:
// NOT_SET -> positive value -> CONSUMED
// or
// NOT_SET -> CONSUMED
// Meaning:
//  - NOT_SET - pool not initialized, user value not set
//  - positive value - pool not initialized, user value set
//  - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};
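// Illustrative walk-through of the state machine above: calling
// at::set_num_threads(8) before any parallel work moves NOT_SET -> 8; the
// first use of the pool then exchanges the value to CONSUMED and creates
// 8 - 1 = 7 worker threads (the calling "master" thread is the eighth).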

int _num_pool_threads(int nthreads) {
  if (nthreads == NOT_SET) {
    nthreads = intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  // Minus one because of the master thread.
  return nthreads - 1;
}

TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // create a separate thread pool for intra-op
  return *pool;
}
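// Note: the function-local static above is initialized exactly once and
// thread-safely (C++11 "magic static"); the exchange to CONSUMED at that
// point is what permanently fixes the pool size.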

#endif // C10_MOBILE

// Run the lambda `fn` over `task_id` in [0, `range`) with the threadpool.
// `fn` is called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
  for (const auto i : c10::irange(1, range)) {
    _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
  }
  // Run the first task on the current thread directly.
  fn(0, 0);
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");

  pool->run(
    // PThreadPool::run() is blocking.  A std::function [const] reference to
    // this lambda cannot go out of scope before PThreadPool::run() returns.
    [&fn](const size_t task_id) {
      fn(0 /* unused */, task_id);
    }, range);
#endif // C10_MOBILE
}
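// For example (illustrative only): on non-mobile, _run_with_pool(fn, 3)
// submits fn(1, 1) and fn(2, 2) to the intra-op pool and runs fn(0, 0)
// inline on the calling thread. Note that this function does not wait for
// the pooled tasks; callers such as invoke_parallel() below synchronize
// themselves.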

// RAII guard that supports the in_parallel_region() and get_thread_num() API.
struct ParallelRegionGuard {
  ParallelRegionGuard(int task_id) {
    internal::set_thread_num(task_id);
    _set_in_parallel_region(true);
  }

  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }
};

} // namespace

namespace internal {

inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size) {
  if ((end - begin) < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, end - begin));
  }
  // Choose the number of tasks based on grain size and number of threads.
  size_t chunk_size = divup((end - begin), get_num_threads());
  // Make sure each task is at least grain_size in size.
  chunk_size = std::max((size_t)grain_size, chunk_size);
  size_t num_tasks = divup((end - begin), chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}
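// Worked example (illustrative): with begin = 0, end = 100, grain_size = 10
// and 4 threads, chunk_size = max(10, divup(100, 4)) = 25 and num_tasks = 4;
// with grain_size = 40 instead, chunk_size = max(40, 25) = 40 and
// num_tasks = divup(100, 40) = 3, i.e. fewer but larger tasks.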

void invoke_parallel(
  const int64_t begin,
  const int64_t end,
  const int64_t grain_size,
  const std::function<void(int64_t, int64_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks = 0, chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);

  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    std::mutex mutex;
    std::atomic_size_t remaining{0};
    std::condition_variable cv;
  } state;

  auto task = [f, &state, begin, end, chunk_size]
      (int /* unused */, size_t task_id) {
    int64_t local_start = begin + task_id * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      try {
        ParallelRegionGuard guard(task_id);
        f(local_start, local_end);
      } catch (...) {
        // Record only the first exception; subsequent ones are dropped.
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(std::move(task), num_tasks);

  // Wait for all tasks to finish; the predicate guards against spurious
  // wakeups and handles the case where all tasks already completed.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    state.cv.wait(lk, [&state] { return state.remaining == 0; });
  }
  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}
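// Illustrative sketch (not part of the build): user code typically reaches
// invoke_parallel() through at::parallel_for, e.g. with hypothetical buffers
// a, b, out:
//   at::parallel_for(0, n, /* grain_size */ 2048,
//       [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; ++i) { out[i] = a[i] + b[i]; }
//   });
// Each task receives a disjoint [begin, end) chunk sized as computed above.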

} // namespace internal

void init_num_threads() {
#ifdef _OPENMP
  omp_set_num_threads(1);
#endif

#if AT_MKL_ENABLED()
  mkl_set_num_threads(1);
#endif

#ifdef C10_MOBILE
  caffe2::pthreadpool();
#endif
}

void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int no_value = NOT_SET;
  if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
    // num_intraop_threads either stores a positive integer or CONSUMED;
    // check that the requested size matches the current one.
    int stored_nthreads = num_intraop_threads.load();
    if (stored_nthreads <= 0) {
      // Plus one because of the master thread.
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      stored_nthreads = _get_intraop_pool().size() + 1;
    }
    if (stored_nthreads != nthreads) {
      TORCH_WARN(
        "Cannot set number of intraop threads "
        "after parallel work has started or after set_num_threads call "
        "when using native parallel backend");
    }
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}
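// Illustrative sketch: at::set_num_threads(6) takes effect only if it runs
// before the intra-op pool is created (NOT_SET -> 6). A later call with a
// different value, or any call after parallel work has started, leaves the
// pool size unchanged and triggers the TORCH_WARN above.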

int get_num_threads() {
  at::internal::lazy_init_num_threads();
#ifndef C10_MOBILE
  // Do not initialize the pool unnecessarily,
  // because it cannot be resized after initialization.
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return _get_intraop_pool().size() + 1;
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}

int get_thread_num() {
  return thread_num_;
}

bool in_parallel_region() {
#ifndef C10_MOBILE
  return in_parallel_region_ || (
    num_intraop_threads.load() == CONSUMED &&
    // Needed because intraop_launch() doesn't set in_parallel_region().
    _get_intraop_pool().inThreadPool()
  );
#else
  return in_parallel_region_;
#endif // C10_MOBILE
}

void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(std::move(func));
  } else {
    // Execute inline if we're in a parallel region.
    func();
  }
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  func();
#endif // C10_MOBILE
}

c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
#ifndef C10_MOBILE
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(
      [func, future]() {
        func();
        future->markCompleted();
      }
    );
  } else {
    func();
    future->markCompleted();
  }
  return future;
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
  func();
  future->markCompleted();
  return future;
#endif // C10_MOBILE
}
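// Illustrative sketch (not part of the build): the returned future lets a
// caller overlap work with the intra-op pool:
//   auto fut = at::intraop_launch_future([] { /* background work */ });
//   // ... do other work on the calling thread ...
//   fut->wait();  // blocks until markCompleted() has run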

} // namespace at
#endif // AT_PARALLEL_NATIVE