1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
2*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_NATIVE
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/ParallelFuture.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/PTThreadPool.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
8*da0073e9SAndroid Build Coastguard Worker #include <c10/core/thread_pool.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
10*da0073e9SAndroid Build Coastguard Worker #else
11*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
12*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker #include <atomic>
15*da0073e9SAndroid Build Coastguard Worker #include <utility>
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP
18*da0073e9SAndroid Build Coastguard Worker #include <omp.h>
19*da0073e9SAndroid Build Coastguard Worker #endif
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED()
22*da0073e9SAndroid Build Coastguard Worker #include <mkl.h>
23*da0073e9SAndroid Build Coastguard Worker #endif
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker namespace at {
namespace {
// Per-thread marker, flipped through _set_in_parallel_region(), that lets the
// master thread count as being inside a parallel region while it is running
// parallel primitives.
thread_local bool in_parallel_region_{false};

// Per-thread task index (task_id) assigned by the parallel primitive.
thread_local int thread_num_{0};

// Updates the calling thread's parallel-region marker to `in_region`.
void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

} // namespace (anonymous)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker namespace internal {
set_thread_num(int thread_num)41*da0073e9SAndroid Build Coastguard Worker void set_thread_num(int thread_num) {
42*da0073e9SAndroid Build Coastguard Worker thread_num_ = thread_num;
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker namespace {
_unset_thread_num()47*da0073e9SAndroid Build Coastguard Worker void _unset_thread_num() {
48*da0073e9SAndroid Build Coastguard Worker thread_num_ = 0;
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
52*da0073e9SAndroid Build Coastguard Worker
// Sentinel values stored in num_intraop_threads (see below).
constexpr int NOT_SET = -1;
constexpr int CONSUMED = -2;

// User-requested intra-op thread count, doubling as the pool's lifecycle
// state. Allowed transitions:
//   NOT_SET -> positive value -> CONSUMED
//   NOT_SET -> CONSUMED
// Interpretation of each state:
//   NOT_SET        - pool not initialized, no user value recorded
//   positive value - pool not initialized, user value recorded
//   CONSUMED       - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};
65*da0073e9SAndroid Build Coastguard Worker
_num_pool_threads(int nthreads)66*da0073e9SAndroid Build Coastguard Worker int _num_pool_threads(int nthreads) {
67*da0073e9SAndroid Build Coastguard Worker if (nthreads == NOT_SET) {
68*da0073e9SAndroid Build Coastguard Worker nthreads = intraop_default_num_threads();
69*da0073e9SAndroid Build Coastguard Worker } else {
70*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(nthreads > 0);
71*da0073e9SAndroid Build Coastguard Worker }
72*da0073e9SAndroid Build Coastguard Worker // minus one because of the master thread
73*da0073e9SAndroid Build Coastguard Worker return nthreads - 1;
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker
_get_intraop_pool()76*da0073e9SAndroid Build Coastguard Worker TaskThreadPoolBase& _get_intraop_pool() {
77*da0073e9SAndroid Build Coastguard Worker static std::shared_ptr<TaskThreadPoolBase> pool =
78*da0073e9SAndroid Build Coastguard Worker ThreadPoolRegistry()->Create(
79*da0073e9SAndroid Build Coastguard Worker "C10",
80*da0073e9SAndroid Build Coastguard Worker /* device_id */ 0,
81*da0073e9SAndroid Build Coastguard Worker /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
82*da0073e9SAndroid Build Coastguard Worker /* create_new */ true); // create a separate thread pool for intra-op
83*da0073e9SAndroid Build Coastguard Worker return *pool;
84*da0073e9SAndroid Build Coastguard Worker }
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
89*da0073e9SAndroid Build Coastguard Worker // `fn` will be called with params: (thread_pool_task_id, task_id).
_run_with_pool(const std::function<void (int,size_t)> & fn,size_t range)90*da0073e9SAndroid Build Coastguard Worker void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
91*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
92*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(1, range)) {
93*da0073e9SAndroid Build Coastguard Worker _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
94*da0073e9SAndroid Build Coastguard Worker }
95*da0073e9SAndroid Build Coastguard Worker // Run the first task on the current thread directly.
96*da0073e9SAndroid Build Coastguard Worker fn(0, 0);
97*da0073e9SAndroid Build Coastguard Worker #else
98*da0073e9SAndroid Build Coastguard Worker caffe2::PThreadPool* const pool = caffe2::pthreadpool();
99*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker pool->run(
102*da0073e9SAndroid Build Coastguard Worker // PThreadPool::run() is blocking. A std::function [const] reference to
103*da0073e9SAndroid Build Coastguard Worker // this lambda cannot go out of scope before PThreadPool::run() returns.
104*da0073e9SAndroid Build Coastguard Worker [&fn](const size_t task_id) {
105*da0073e9SAndroid Build Coastguard Worker fn(0 /* unused */, task_id);
106*da0073e9SAndroid Build Coastguard Worker }, range);
107*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
108*da0073e9SAndroid Build Coastguard Worker }
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker // RAII guard helps to support in_parallel_region() and get_thread_num() API.
111*da0073e9SAndroid Build Coastguard Worker struct ParallelRegionGuard {
ParallelRegionGuardat::__anona5067dc90211::ParallelRegionGuard112*da0073e9SAndroid Build Coastguard Worker ParallelRegionGuard(int task_id) {
113*da0073e9SAndroid Build Coastguard Worker internal::set_thread_num(task_id);
114*da0073e9SAndroid Build Coastguard Worker _set_in_parallel_region(true);
115*da0073e9SAndroid Build Coastguard Worker }
116*da0073e9SAndroid Build Coastguard Worker
~ParallelRegionGuardat::__anona5067dc90211::ParallelRegionGuard117*da0073e9SAndroid Build Coastguard Worker ~ParallelRegionGuard() {
118*da0073e9SAndroid Build Coastguard Worker _set_in_parallel_region(false);
119*da0073e9SAndroid Build Coastguard Worker _unset_thread_num();
120*da0073e9SAndroid Build Coastguard Worker }
121*da0073e9SAndroid Build Coastguard Worker };
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker } // namespace
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker namespace internal {
126*da0073e9SAndroid Build Coastguard Worker
calc_num_tasks_and_chunk_size(int64_t begin,int64_t end,int64_t grain_size)127*da0073e9SAndroid Build Coastguard Worker inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
128*da0073e9SAndroid Build Coastguard Worker int64_t begin, int64_t end, int64_t grain_size) {
129*da0073e9SAndroid Build Coastguard Worker if ((end - begin) < grain_size) {
130*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(1, std::max((int64_t)0, end - begin));
131*da0073e9SAndroid Build Coastguard Worker }
132*da0073e9SAndroid Build Coastguard Worker // Choose number of tasks based on grain size and number of threads.
133*da0073e9SAndroid Build Coastguard Worker size_t chunk_size = divup((end - begin), get_num_threads());
134*da0073e9SAndroid Build Coastguard Worker // Make sure each task is at least grain_size size.
135*da0073e9SAndroid Build Coastguard Worker chunk_size = std::max((size_t)grain_size, chunk_size);
136*da0073e9SAndroid Build Coastguard Worker size_t num_tasks = divup((end - begin), chunk_size);
137*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(num_tasks, chunk_size);
138*da0073e9SAndroid Build Coastguard Worker }
139*da0073e9SAndroid Build Coastguard Worker
invoke_parallel(const int64_t begin,const int64_t end,const int64_t grain_size,const std::function<void (int64_t,int64_t)> & f)140*da0073e9SAndroid Build Coastguard Worker void invoke_parallel(
141*da0073e9SAndroid Build Coastguard Worker const int64_t begin,
142*da0073e9SAndroid Build Coastguard Worker const int64_t end,
143*da0073e9SAndroid Build Coastguard Worker const int64_t grain_size,
144*da0073e9SAndroid Build Coastguard Worker const std::function<void(int64_t, int64_t)>& f) {
145*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads();
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker size_t num_tasks = 0, chunk_size = 0;
148*da0073e9SAndroid Build Coastguard Worker std::tie(num_tasks, chunk_size) =
149*da0073e9SAndroid Build Coastguard Worker internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker struct {
152*da0073e9SAndroid Build Coastguard Worker std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
153*da0073e9SAndroid Build Coastguard Worker std::exception_ptr eptr;
154*da0073e9SAndroid Build Coastguard Worker std::mutex mutex;
155*da0073e9SAndroid Build Coastguard Worker std::atomic_size_t remaining{0};
156*da0073e9SAndroid Build Coastguard Worker std::condition_variable cv;
157*da0073e9SAndroid Build Coastguard Worker } state;
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker auto task = [f, &state, begin, end, chunk_size]
160*da0073e9SAndroid Build Coastguard Worker (int /* unused */, size_t task_id) {
161*da0073e9SAndroid Build Coastguard Worker int64_t local_start = begin + task_id * chunk_size;
162*da0073e9SAndroid Build Coastguard Worker if (local_start < end) {
163*da0073e9SAndroid Build Coastguard Worker int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
164*da0073e9SAndroid Build Coastguard Worker try {
165*da0073e9SAndroid Build Coastguard Worker ParallelRegionGuard guard(task_id);
166*da0073e9SAndroid Build Coastguard Worker f(local_start, local_end);
167*da0073e9SAndroid Build Coastguard Worker } catch (...) {
168*da0073e9SAndroid Build Coastguard Worker if (!state.err_flag.test_and_set()) {
169*da0073e9SAndroid Build Coastguard Worker state.eptr = std::current_exception();
170*da0073e9SAndroid Build Coastguard Worker }
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker }
173*da0073e9SAndroid Build Coastguard Worker {
174*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::mutex> lk(state.mutex);
175*da0073e9SAndroid Build Coastguard Worker if (--state.remaining == 0) {
176*da0073e9SAndroid Build Coastguard Worker state.cv.notify_one();
177*da0073e9SAndroid Build Coastguard Worker }
178*da0073e9SAndroid Build Coastguard Worker }
179*da0073e9SAndroid Build Coastguard Worker };
180*da0073e9SAndroid Build Coastguard Worker state.remaining = num_tasks;
181*da0073e9SAndroid Build Coastguard Worker _run_with_pool(std::move(task), num_tasks);
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker // Wait for all tasks to finish.
184*da0073e9SAndroid Build Coastguard Worker {
185*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::mutex> lk(state.mutex);
186*da0073e9SAndroid Build Coastguard Worker if (state.remaining != 0) {
187*da0073e9SAndroid Build Coastguard Worker state.cv.wait(lk);
188*da0073e9SAndroid Build Coastguard Worker }
189*da0073e9SAndroid Build Coastguard Worker }
190*da0073e9SAndroid Build Coastguard Worker if (state.eptr) {
191*da0073e9SAndroid Build Coastguard Worker std::rethrow_exception(state.eptr);
192*da0073e9SAndroid Build Coastguard Worker }
193*da0073e9SAndroid Build Coastguard Worker }
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker } // namespace internal
196*da0073e9SAndroid Build Coastguard Worker
init_num_threads()197*da0073e9SAndroid Build Coastguard Worker void init_num_threads() {
198*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP
199*da0073e9SAndroid Build Coastguard Worker omp_set_num_threads(1);
200*da0073e9SAndroid Build Coastguard Worker #endif
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED()
203*da0073e9SAndroid Build Coastguard Worker mkl_set_num_threads(1);
204*da0073e9SAndroid Build Coastguard Worker #endif
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker #ifdef C10_MOBILE
207*da0073e9SAndroid Build Coastguard Worker caffe2::pthreadpool();
208*da0073e9SAndroid Build Coastguard Worker #endif
209*da0073e9SAndroid Build Coastguard Worker }
210*da0073e9SAndroid Build Coastguard Worker
set_num_threads(int nthreads)211*da0073e9SAndroid Build Coastguard Worker void set_num_threads(int nthreads) {
212*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
213*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
214*da0073e9SAndroid Build Coastguard Worker int no_value = NOT_SET;
215*da0073e9SAndroid Build Coastguard Worker if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
216*da0073e9SAndroid Build Coastguard Worker // num_intraop_threads either stores a positive integer or CONSUMED,
217*da0073e9SAndroid Build Coastguard Worker // check that requested size is the same as the current one
218*da0073e9SAndroid Build Coastguard Worker int stored_nthreads = num_intraop_threads.load();
219*da0073e9SAndroid Build Coastguard Worker if (stored_nthreads <= 0) {
220*da0073e9SAndroid Build Coastguard Worker // plus one because of master thread
221*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
222*da0073e9SAndroid Build Coastguard Worker stored_nthreads = _get_intraop_pool().size() + 1;
223*da0073e9SAndroid Build Coastguard Worker }
224*da0073e9SAndroid Build Coastguard Worker if (stored_nthreads != nthreads) {
225*da0073e9SAndroid Build Coastguard Worker TORCH_WARN(
226*da0073e9SAndroid Build Coastguard Worker "Cannot set number of intraop threads "
227*da0073e9SAndroid Build Coastguard Worker "after parallel work has started or after set_num_threads call "
228*da0073e9SAndroid Build Coastguard Worker "when using native parallel backend");
229*da0073e9SAndroid Build Coastguard Worker }
230*da0073e9SAndroid Build Coastguard Worker }
231*da0073e9SAndroid Build Coastguard Worker #else
232*da0073e9SAndroid Build Coastguard Worker caffe2::PThreadPool* const pool = caffe2::pthreadpool();
233*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
234*da0073e9SAndroid Build Coastguard Worker pool->set_thread_count(nthreads);
235*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
236*da0073e9SAndroid Build Coastguard Worker }
237*da0073e9SAndroid Build Coastguard Worker
get_num_threads()238*da0073e9SAndroid Build Coastguard Worker int get_num_threads() {
239*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads();
240*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
241*da0073e9SAndroid Build Coastguard Worker // not initializing pool unnecessarily,
242*da0073e9SAndroid Build Coastguard Worker // because pool cannot be resized after initialization
243*da0073e9SAndroid Build Coastguard Worker int nthreads = num_intraop_threads.load();
244*da0073e9SAndroid Build Coastguard Worker if (nthreads > 0) {
245*da0073e9SAndroid Build Coastguard Worker return nthreads;
246*da0073e9SAndroid Build Coastguard Worker } else if (nthreads == NOT_SET) {
247*da0073e9SAndroid Build Coastguard Worker return intraop_default_num_threads();
248*da0073e9SAndroid Build Coastguard Worker } else {
249*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
250*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
251*da0073e9SAndroid Build Coastguard Worker return _get_intraop_pool().size() + 1;
252*da0073e9SAndroid Build Coastguard Worker }
253*da0073e9SAndroid Build Coastguard Worker #else
254*da0073e9SAndroid Build Coastguard Worker caffe2::PThreadPool* const pool = caffe2::pthreadpool();
255*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!")
256*da0073e9SAndroid Build Coastguard Worker return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
257*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
258*da0073e9SAndroid Build Coastguard Worker }
259*da0073e9SAndroid Build Coastguard Worker
get_thread_num()260*da0073e9SAndroid Build Coastguard Worker int get_thread_num() {
261*da0073e9SAndroid Build Coastguard Worker return thread_num_;
262*da0073e9SAndroid Build Coastguard Worker }
263*da0073e9SAndroid Build Coastguard Worker
in_parallel_region()264*da0073e9SAndroid Build Coastguard Worker bool in_parallel_region() {
265*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
266*da0073e9SAndroid Build Coastguard Worker return in_parallel_region_ || (
267*da0073e9SAndroid Build Coastguard Worker num_intraop_threads.load() == CONSUMED &&
268*da0073e9SAndroid Build Coastguard Worker // Needed as intraop_launch() doesn't set in_parallel_region().
269*da0073e9SAndroid Build Coastguard Worker _get_intraop_pool().inThreadPool()
270*da0073e9SAndroid Build Coastguard Worker );
271*da0073e9SAndroid Build Coastguard Worker #else
272*da0073e9SAndroid Build Coastguard Worker return in_parallel_region_;
273*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
274*da0073e9SAndroid Build Coastguard Worker }
275*da0073e9SAndroid Build Coastguard Worker
intraop_launch(std::function<void ()> func)276*da0073e9SAndroid Build Coastguard Worker void intraop_launch(std::function<void()> func) {
277*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
278*da0073e9SAndroid Build Coastguard Worker if (!in_parallel_region() && get_num_threads() > 1) {
279*da0073e9SAndroid Build Coastguard Worker _get_intraop_pool().run(std::move(func));
280*da0073e9SAndroid Build Coastguard Worker } else {
281*da0073e9SAndroid Build Coastguard Worker // execute inline if we're in parallel region
282*da0073e9SAndroid Build Coastguard Worker func();
283*da0073e9SAndroid Build Coastguard Worker }
284*da0073e9SAndroid Build Coastguard Worker #else
285*da0073e9SAndroid Build Coastguard Worker // TODO: caffe2::PThreadPool only provides a data-parallel API.
286*da0073e9SAndroid Build Coastguard Worker // Task parallelism is not currently supported.
287*da0073e9SAndroid Build Coastguard Worker func();
288*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
289*da0073e9SAndroid Build Coastguard Worker }
290*da0073e9SAndroid Build Coastguard Worker
intraop_launch_future(std::function<void ()> func)291*da0073e9SAndroid Build Coastguard Worker c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
292*da0073e9SAndroid Build Coastguard Worker std::function<void()> func) {
293*da0073e9SAndroid Build Coastguard Worker #ifndef C10_MOBILE
294*da0073e9SAndroid Build Coastguard Worker auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
295*da0073e9SAndroid Build Coastguard Worker if (!in_parallel_region() && get_num_threads() > 1) {
296*da0073e9SAndroid Build Coastguard Worker _get_intraop_pool().run(
297*da0073e9SAndroid Build Coastguard Worker [func, future]() {
298*da0073e9SAndroid Build Coastguard Worker func();
299*da0073e9SAndroid Build Coastguard Worker future->markCompleted();
300*da0073e9SAndroid Build Coastguard Worker }
301*da0073e9SAndroid Build Coastguard Worker );
302*da0073e9SAndroid Build Coastguard Worker } else {
303*da0073e9SAndroid Build Coastguard Worker func();
304*da0073e9SAndroid Build Coastguard Worker future->markCompleted();
305*da0073e9SAndroid Build Coastguard Worker }
306*da0073e9SAndroid Build Coastguard Worker return future;
307*da0073e9SAndroid Build Coastguard Worker #else
308*da0073e9SAndroid Build Coastguard Worker // TODO: caffe2::PThreadPool only provides a data-parallel API.
309*da0073e9SAndroid Build Coastguard Worker // Task parallelism is not currently supported.
310*da0073e9SAndroid Build Coastguard Worker auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
311*da0073e9SAndroid Build Coastguard Worker func();
312*da0073e9SAndroid Build Coastguard Worker future->markCompleted();
313*da0073e9SAndroid Build Coastguard Worker return future;
314*da0073e9SAndroid Build Coastguard Worker #endif // C10_MOBILE
315*da0073e9SAndroid Build Coastguard Worker }
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker } // namespace at
318*da0073e9SAndroid Build Coastguard Worker #endif
319