xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ParallelOpenMP.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Config.h>
2 #include <ATen/core/jit_type.h>
3 #if AT_PARALLEL_OPENMP
4 #include <ATen/Parallel.h>
5 #include <ATen/ParallelFuture.h>
6 
7 #include <atomic>
8 
9 #if AT_MKL_ENABLED()
10 #include <mkl.h>
11 #endif
12 
13 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
14 
15 namespace at {
16 #if AT_MKLDNN_ENABLED()
17 namespace native { namespace mkldnn {
18 void clear_computation_cache();
19 }} // namespace native::mkldnn
20 #endif
21 
namespace {
// Number of threads set by the user via set_num_threads(); -1 means
// "not explicitly set" (init_num_threads() then falls back to defaults).
std::atomic<int> num_threads{-1};
// Per-thread id within a parallel region; written by
// internal::set_thread_num() and defaults to 0 (main / serial thread).
thread_local int this_thread_id{0};

} // namespace
28 
// Initializes the OpenMP thread-pool size for the calling thread.
// If the user has already set a count via set_num_threads(), re-apply it;
// otherwise pick a sensible default (MKL's max when MKL shares the OpenMP
// runtime, else ATen's intra-op default).
void init_num_threads() {
  auto nthreads = num_threads.load();
  if (nthreads > 0) {
    set_num_threads(nthreads);
  } else {
#if defined(_OPENMP) && AT_MKL_ENABLED() && !AT_MKL_SEQUENTIAL()
    // If we are using MKL and OpenMP, make sure the number of threads match.
    // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
    // size of the OpenMP thread pool, resulting in worse performance (and memory
    // leaks in GCC 5.4)
    omp_set_num_threads(mkl_get_max_threads());
#elif defined(_OPENMP)
    omp_set_num_threads(intraop_default_num_threads());
#endif
  }
}
45 
// Sets the intra-op thread count and propagates it to every threading
// backend compiled into this build (OpenMP, MKL, caffe2 pthreadpool).
// Throws (via TORCH_CHECK) if nthreads is not positive.
void set_num_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  // Remember the request so init_num_threads() can re-apply it on threads
  // initialized later.
  num_threads.store(nthreads);
#ifdef _OPENMP
  omp_set_num_threads(nthreads);
#endif
#if AT_MKL_ENABLED()
  mkl_set_num_threads_local(nthreads);

  // because PyTorch uses OpenMP outside of MKL invocations
  // as well, we want this flag to be false, so that
  // threads aren't destroyed and recreated across every
  // MKL / non-MKL boundary of OpenMP usage
  // See https://github.com/pytorch/pytorch/issues/13757
  mkl_set_dynamic(false);
#endif
#ifdef USE_PTHREADPOOL
  // because PyTorch uses caffe2::pthreadpool() in QNNPACK
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif
#if AT_MKLDNN_ENABLED()
  // NOTE(review): presumably cached MKLDNN/oneDNN computations depend on the
  // thread count and must be invalidated here — confirm against
  // native::mkldnn::clear_computation_cache().
  at::native::mkldnn::clear_computation_cache();
#endif
}
72 
// Returns the current intra-op thread count (1 in non-OpenMP builds).
// Explicitly calling omp_get_max_threads() as the size of the parallel
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
int get_num_threads() {
#ifdef _OPENMP
  // Lazily initialize this thread's pool size before querying it.
  at::internal::lazy_init_num_threads();
  return omp_get_max_threads();
#else
  return 1;
#endif
}
85 
get_thread_num()86 int get_thread_num() {
87   return this_thread_id;
88 }
89 
90 namespace internal {
set_thread_num(int id)91 void set_thread_num(int id) {
92   this_thread_id = id;
93 }
94 }
95 
// True when the caller is currently executing inside an active OpenMP
// parallel region; always false in builds without OpenMP.
bool in_parallel_region() {
#ifdef _OPENMP
  return omp_in_parallel();
#else
  return false;
#endif
}
103 
// Runs the task synchronously on the calling thread — in the OpenMP backend
// there is no separate launch mechanism, so the call is executed inline.
void intraop_launch(std::function<void()> func) {
  func();
}
108 
// Future-returning variant of intraop_launch(): the task is executed inline
// on the calling thread, then an already-completed Future (holding None) is
// returned so callers can uniformly wait on it.
c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
  func();
  // Work is finished by the time we construct the future, so mark it
  // completed immediately.
  auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
  future->markCompleted();
  return future;
}
116 
117 } // namespace at
118 #endif
119