#include <ATen/Config.h>
#include <ATen/core/jit_type.h>
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>

#include <atomic>

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

#include <caffe2/utils/threadpool/pthreadpool-cpp.h>

namespace at {
#if AT_MKLDNN_ENABLED()
namespace native { namespace mkldnn {
// Forward declaration (defined in the MKL-DNN backend). set_num_threads()
// below invalidates this cache whenever the thread count changes.
void clear_computation_cache();
}} // namespace native::mkldnn
#endif

namespace {
// Number of threads set by the user via set_num_threads(); -1 means "unset",
// in which case init_num_threads() falls back to the backend default.
std::atomic<int> num_threads{-1};
// Per-thread logical id inside a parallel region; 0 outside one.
// Written only by internal::set_thread_num().
thread_local int this_thread_id{0};

} // namespace

// Initializes the OpenMP thread-pool size for the calling thread.
// If the user has requested an explicit count, re-applies it; otherwise
// syncs OpenMP to MKL's pool size (or the intra-op default) so the two
// don't keep resizing each other's pool.
// Call during thread initialization so every thread agrees on the size
// of the parallel region (see the comment above get_num_threads()).
void init_num_threads() {
  auto nthreads = num_threads.load();
  if (nthreads > 0) {
    set_num_threads(nthreads);
  } else {
#if defined(_OPENMP) && AT_MKL_ENABLED() && !AT_MKL_SEQUENTIAL()
    // If we are using MKL and OpenMP, make sure the number of threads match.
    // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
    // size of the OpenMP thread pool, resulting in worse performance (and memory
    // leaks in GCC 5.4)
    omp_set_num_threads(mkl_get_max_threads());
#elif defined(_OPENMP)
    omp_set_num_threads(intraop_default_num_threads());
#endif
  }
}

// Sets the intra-op thread count and propagates it to every compiled-in
// backend: OpenMP, MKL, and the caffe2 pthreadpool (used by QNNPACK).
// Throws (TORCH_CHECK) if nthreads is not strictly positive.
void set_num_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  num_threads.store(nthreads);
#ifdef _OPENMP
  omp_set_num_threads(nthreads);
#endif
#if AT_MKL_ENABLED()
  mkl_set_num_threads_local(nthreads);

  // because PyTorch uses OpenMP outside of MKL invocations
  // as well, we want this flag to be false, so that
  // threads aren't destroyed and recreated across every
  // MKL / non-MKL boundary of OpenMP usage
  // See https://github.com/pytorch/pytorch/issues/13757
  mkl_set_dynamic(false);
#endif
#ifdef USE_PTHREADPOOL
  // because PyTorch uses caffe2::pthreadpool() in QNNPACK
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif
#if AT_MKLDNN_ENABLED()
  // Cached MKL-DNN computations may depend on the thread count; drop them.
  at::native::mkldnn::clear_computation_cache();
#endif
}

// Explicitly calling omp_get_max_threads() as the size of the parallel
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
int get_num_threads() {
#ifdef _OPENMP
  at::internal::lazy_init_num_threads();
  return omp_get_max_threads();
#else
  return 1;
#endif
}

// Returns the caller's logical thread id as recorded by
// internal::set_thread_num() (0 if it was never set on this thread).
int get_thread_num() {
  return this_thread_id;
}

namespace internal {
// Records the logical id of the calling thread; read back by
// get_thread_num() above.
void set_thread_num(int id) {
  this_thread_id = id;
}
}

// True iff the caller is currently inside an OpenMP parallel region
// (always false when built without OpenMP).
bool in_parallel_region() {
#ifdef _OPENMP
  return omp_in_parallel();
#else
  return false;
#endif
}

// The OpenMP backend has no separate inter-op task pool, so the task is
// simply run synchronously on the calling thread.
void intraop_launch(std::function<void()> func) {
  // execute inline in openmp case
  func();
}

// Runs func synchronously (see intraop_launch above) and returns a
// Future of NoneType that is already marked completed.
c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
  func();
  auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
  future->markCompleted();
  return future;
}

} // namespace at
#endif