xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ParallelOpenMP.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/jit_type.h>
3*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_OPENMP
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ParallelFuture.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <atomic>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED()
10*da0073e9SAndroid Build Coastguard Worker #include <mkl.h>
11*da0073e9SAndroid Build Coastguard Worker #endif
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker namespace at {
16*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED()
17*da0073e9SAndroid Build Coastguard Worker namespace native { namespace mkldnn {
18*da0073e9SAndroid Build Coastguard Worker void clear_computation_cache();
19*da0073e9SAndroid Build Coastguard Worker }} // namespace native::mkldnn
20*da0073e9SAndroid Build Coastguard Worker #endif
21*da0073e9SAndroid Build Coastguard Worker 
namespace {
// Thread count most recently requested through set_num_threads(). It starts
// at -1, and init_num_threads() treats any non-positive value as "nothing
// requested yet".
std::atomic<int> num_threads{-1};

// Per-thread id: reported by get_thread_num() and overwritten by
// internal::set_thread_num(). Every thread starts out with 0.
thread_local int this_thread_id = 0;

} // namespace
28*da0073e9SAndroid Build Coastguard Worker 
init_num_threads()29*da0073e9SAndroid Build Coastguard Worker void init_num_threads() {
30*da0073e9SAndroid Build Coastguard Worker   auto nthreads = num_threads.load();
31*da0073e9SAndroid Build Coastguard Worker   if (nthreads > 0) {
32*da0073e9SAndroid Build Coastguard Worker     set_num_threads(nthreads);
33*da0073e9SAndroid Build Coastguard Worker   } else {
34*da0073e9SAndroid Build Coastguard Worker #if defined(_OPENMP) && AT_MKL_ENABLED() && !AT_MKL_SEQUENTIAL()
35*da0073e9SAndroid Build Coastguard Worker     // If we are using MKL an OpenMP make sure the number of threads match.
36*da0073e9SAndroid Build Coastguard Worker     // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
37*da0073e9SAndroid Build Coastguard Worker     // size of the OpenMP thread pool, resulting in worse performance (and memory
38*da0073e9SAndroid Build Coastguard Worker     // leaks in GCC 5.4)
39*da0073e9SAndroid Build Coastguard Worker     omp_set_num_threads(mkl_get_max_threads());
40*da0073e9SAndroid Build Coastguard Worker #elif defined(_OPENMP)
41*da0073e9SAndroid Build Coastguard Worker     omp_set_num_threads(intraop_default_num_threads());
42*da0073e9SAndroid Build Coastguard Worker #endif
43*da0073e9SAndroid Build Coastguard Worker   }
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker 
set_num_threads(int nthreads)46*da0073e9SAndroid Build Coastguard Worker void set_num_threads(int nthreads) {
47*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
48*da0073e9SAndroid Build Coastguard Worker   num_threads.store(nthreads);
49*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP
50*da0073e9SAndroid Build Coastguard Worker   omp_set_num_threads(nthreads);
51*da0073e9SAndroid Build Coastguard Worker #endif
52*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED()
53*da0073e9SAndroid Build Coastguard Worker   mkl_set_num_threads_local(nthreads);
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   // because PyTorch uses OpenMP outside of MKL invocations
56*da0073e9SAndroid Build Coastguard Worker   // as well, we want this flag to be false, so that
57*da0073e9SAndroid Build Coastguard Worker   // threads aren't destroyed and recreated across every
58*da0073e9SAndroid Build Coastguard Worker   // MKL / non-MKL boundary of OpenMP usage
59*da0073e9SAndroid Build Coastguard Worker   // See https://github.com/pytorch/pytorch/issues/13757
60*da0073e9SAndroid Build Coastguard Worker   mkl_set_dynamic(false);
61*da0073e9SAndroid Build Coastguard Worker #endif
62*da0073e9SAndroid Build Coastguard Worker #ifdef USE_PTHREADPOOL
63*da0073e9SAndroid Build Coastguard Worker   // because PyTorch uses caffe2::pthreadpool() in QNNPACK
64*da0073e9SAndroid Build Coastguard Worker   caffe2::PThreadPool* const pool = caffe2::pthreadpool();
65*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
66*da0073e9SAndroid Build Coastguard Worker   pool->set_thread_count(nthreads);
67*da0073e9SAndroid Build Coastguard Worker #endif
68*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED()
69*da0073e9SAndroid Build Coastguard Worker   at::native::mkldnn::clear_computation_cache();
70*da0073e9SAndroid Build Coastguard Worker #endif
71*da0073e9SAndroid Build Coastguard Worker }
72*da0073e9SAndroid Build Coastguard Worker 
// Returns the number of threads available for a parallel region, or 1 when
// the build has no OpenMP support.
//
// omp_get_max_threads() is queried explicitly on every call because the size
// of the parallel region may differ in a newly created thread. Call
// init_num_threads() during thread initialization to keep the parallel region
// size consistent across threads.
int get_num_threads() {
#ifndef _OPENMP
  return 1;
#else
  at::internal::lazy_init_num_threads();
  return omp_get_max_threads();
#endif
}
85*da0073e9SAndroid Build Coastguard Worker 
get_thread_num()86*da0073e9SAndroid Build Coastguard Worker int get_thread_num() {
87*da0073e9SAndroid Build Coastguard Worker   return this_thread_id;
88*da0073e9SAndroid Build Coastguard Worker }
89*da0073e9SAndroid Build Coastguard Worker 
90*da0073e9SAndroid Build Coastguard Worker namespace internal {
set_thread_num(int id)91*da0073e9SAndroid Build Coastguard Worker void set_thread_num(int id) {
92*da0073e9SAndroid Build Coastguard Worker   this_thread_id = id;
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker }
95*da0073e9SAndroid Build Coastguard Worker 
// Reports whether the caller is currently inside an OpenMP parallel region,
// as answered by omp_in_parallel(). Always false in builds without OpenMP.
bool in_parallel_region() {
#ifndef _OPENMP
  return false;
#else
  return omp_in_parallel();
#endif
}
103*da0073e9SAndroid Build Coastguard Worker 
// Runs `func` immediately on the calling thread: in the OpenMP backend an
// intra-op launch is executed inline rather than handed off elsewhere.
void intraop_launch(std::function<void()> func) {
  func();
}
108*da0073e9SAndroid Build Coastguard Worker 
intraop_launch_future(std::function<void ()> func)109*da0073e9SAndroid Build Coastguard Worker c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
110*da0073e9SAndroid Build Coastguard Worker     std::function<void()> func) {
111*da0073e9SAndroid Build Coastguard Worker   func();
112*da0073e9SAndroid Build Coastguard Worker   auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
113*da0073e9SAndroid Build Coastguard Worker   future->markCompleted();
114*da0073e9SAndroid Build Coastguard Worker   return future;
115*da0073e9SAndroid Build Coastguard Worker }
116*da0073e9SAndroid Build Coastguard Worker 
117*da0073e9SAndroid Build Coastguard Worker } // namespace at
118*da0073e9SAndroid Build Coastguard Worker #endif
119