1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h> 2*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/jit_type.h> 3*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_OPENMP 4*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h> 5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ParallelFuture.h> 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker #include <atomic> 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED() 10*da0073e9SAndroid Build Coastguard Worker #include <mkl.h> 11*da0073e9SAndroid Build Coastguard Worker #endif 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool-cpp.h> 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker namespace at { 16*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED() 17*da0073e9SAndroid Build Coastguard Worker namespace native { namespace mkldnn { 18*da0073e9SAndroid Build Coastguard Worker void clear_computation_cache(); 19*da0073e9SAndroid Build Coastguard Worker }} // namespace native::mkldnn 20*da0073e9SAndroid Build Coastguard Worker #endif 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker namespace { 23*da0073e9SAndroid Build Coastguard Worker // Number of threads set by the user 24*da0073e9SAndroid Build Coastguard Worker std::atomic<int> num_threads{-1}; 25*da0073e9SAndroid Build Coastguard Worker thread_local int this_thread_id{0}; 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker } // namespace 28*da0073e9SAndroid Build Coastguard Worker init_num_threads()29*da0073e9SAndroid Build Coastguard Workervoid init_num_threads() { 30*da0073e9SAndroid Build Coastguard Worker auto nthreads = num_threads.load(); 31*da0073e9SAndroid Build Coastguard Worker if (nthreads > 0) { 32*da0073e9SAndroid Build Coastguard Worker set_num_threads(nthreads); 33*da0073e9SAndroid Build Coastguard Worker } else { 34*da0073e9SAndroid Build Coastguard Worker #if defined(_OPENMP) && AT_MKL_ENABLED() && !AT_MKL_SEQUENTIAL() 35*da0073e9SAndroid Build Coastguard Worker // If we are using MKL an OpenMP make sure the number of threads match. 36*da0073e9SAndroid Build Coastguard Worker // Otherwise, MKL and our OpenMP-enabled functions will keep changing the 37*da0073e9SAndroid Build Coastguard Worker // size of the OpenMP thread pool, resulting in worse performance (and memory 38*da0073e9SAndroid Build Coastguard Worker // leaks in GCC 5.4) 39*da0073e9SAndroid Build Coastguard Worker omp_set_num_threads(mkl_get_max_threads()); 40*da0073e9SAndroid Build Coastguard Worker #elif defined(_OPENMP) 41*da0073e9SAndroid Build Coastguard Worker omp_set_num_threads(intraop_default_num_threads()); 42*da0073e9SAndroid Build Coastguard Worker #endif 43*da0073e9SAndroid Build Coastguard Worker } 44*da0073e9SAndroid Build Coastguard Worker } 45*da0073e9SAndroid Build Coastguard Worker set_num_threads(int nthreads)46*da0073e9SAndroid Build Coastguard Workervoid set_num_threads(int nthreads) { 47*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); 48*da0073e9SAndroid Build Coastguard Worker num_threads.store(nthreads); 49*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP 50*da0073e9SAndroid Build Coastguard Worker omp_set_num_threads(nthreads); 51*da0073e9SAndroid Build Coastguard Worker #endif 52*da0073e9SAndroid Build Coastguard Worker #if AT_MKL_ENABLED() 53*da0073e9SAndroid Build Coastguard Worker mkl_set_num_threads_local(nthreads); 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker // because PyTorch uses OpenMP outside of MKL invocations 56*da0073e9SAndroid Build Coastguard Worker // as well, we want this flag to be false, so that 57*da0073e9SAndroid Build Coastguard Worker // threads aren't destroyed and recreated across every 58*da0073e9SAndroid Build Coastguard Worker // MKL / non-MKL boundary of OpenMP usage 59*da0073e9SAndroid Build Coastguard Worker // See https://github.com/pytorch/pytorch/issues/13757 60*da0073e9SAndroid Build Coastguard Worker mkl_set_dynamic(false); 61*da0073e9SAndroid Build Coastguard Worker #endif 62*da0073e9SAndroid Build Coastguard Worker #ifdef USE_PTHREADPOOL 63*da0073e9SAndroid Build Coastguard Worker // because PyTorch uses caffe2::pthreadpool() in QNNPACK 64*da0073e9SAndroid Build Coastguard Worker caffe2::PThreadPool* const pool = caffe2::pthreadpool(); 65*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); 66*da0073e9SAndroid Build Coastguard Worker pool->set_thread_count(nthreads); 67*da0073e9SAndroid Build Coastguard Worker #endif 68*da0073e9SAndroid Build Coastguard Worker #if AT_MKLDNN_ENABLED() 69*da0073e9SAndroid Build Coastguard Worker at::native::mkldnn::clear_computation_cache(); 70*da0073e9SAndroid Build Coastguard Worker #endif 71*da0073e9SAndroid Build Coastguard Worker } 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker // Explicitly calling omp_get_max_threads() as the size of the parallel 74*da0073e9SAndroid Build Coastguard Worker // region might be different in the new thread; 75*da0073e9SAndroid Build Coastguard Worker // Use init_num_threads() during thread initialization to ensure 76*da0073e9SAndroid Build Coastguard Worker // consistent size of parallel region in different threads get_num_threads()77*da0073e9SAndroid Build Coastguard Workerint get_num_threads() { 78*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP 79*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads(); 80*da0073e9SAndroid Build Coastguard Worker return omp_get_max_threads(); 81*da0073e9SAndroid Build Coastguard Worker #else 82*da0073e9SAndroid Build Coastguard Worker return 1; 83*da0073e9SAndroid Build Coastguard Worker #endif 84*da0073e9SAndroid Build Coastguard Worker } 85*da0073e9SAndroid Build Coastguard Worker get_thread_num()86*da0073e9SAndroid Build Coastguard Workerint get_thread_num() { 87*da0073e9SAndroid Build Coastguard Worker return this_thread_id; 88*da0073e9SAndroid Build Coastguard Worker } 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker namespace internal { set_thread_num(int id)91*da0073e9SAndroid Build Coastguard Workervoid set_thread_num(int id) { 92*da0073e9SAndroid Build Coastguard Worker this_thread_id = id; 93*da0073e9SAndroid Build Coastguard Worker } 94*da0073e9SAndroid Build Coastguard Worker } 95*da0073e9SAndroid Build Coastguard Worker in_parallel_region()96*da0073e9SAndroid Build Coastguard Workerbool in_parallel_region() { 97*da0073e9SAndroid Build Coastguard Worker #ifdef _OPENMP 98*da0073e9SAndroid Build Coastguard Worker return omp_in_parallel(); 99*da0073e9SAndroid Build Coastguard Worker #else 100*da0073e9SAndroid Build Coastguard Worker return false; 101*da0073e9SAndroid Build Coastguard Worker #endif 102*da0073e9SAndroid Build Coastguard Worker } 103*da0073e9SAndroid Build Coastguard Worker intraop_launch(std::function<void ()> func)104*da0073e9SAndroid Build Coastguard Workervoid intraop_launch(std::function<void()> func) { 105*da0073e9SAndroid Build Coastguard Worker // execute inline in openmp case 106*da0073e9SAndroid Build Coastguard Worker func(); 107*da0073e9SAndroid Build Coastguard Worker } 108*da0073e9SAndroid Build Coastguard Worker intraop_launch_future(std::function<void ()> func)109*da0073e9SAndroid Build Coastguard Workerc10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future( 110*da0073e9SAndroid Build Coastguard Worker std::function<void()> func) { 111*da0073e9SAndroid Build Coastguard Worker func(); 112*da0073e9SAndroid Build Coastguard Worker auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get()); 113*da0073e9SAndroid Build Coastguard Worker future->markCompleted(); 114*da0073e9SAndroid Build Coastguard Worker return future; 115*da0073e9SAndroid Build Coastguard Worker } 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker } // namespace at 118*da0073e9SAndroid Build Coastguard Worker #endif 119