xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ParallelCommon.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Parallel.h>
2 
3 #include <ATen/Config.h>
4 #include <ATen/PTThreadPool.h>
5 #include <ATen/Version.h>
6 
7 #include <sstream>
8 #include <thread>
9 
10 #if AT_MKL_ENABLED()
11 #include <mkl.h>
12 #endif
13 
14 #ifdef _OPENMP
15 #include <omp.h>
16 #endif
17 
18 #if defined(__APPLE__) && defined(__aarch64__) && !defined(C10_MOBILE)
19 #include <sys/sysctl.h>
20 #endif
21 
22 namespace at {
23 
24 namespace {
25 
// Looks up the environment variable `var_name`.
//
// Returns the pointer handed back by std::getenv when the variable is set
// (including when it is set to an empty string); otherwise returns
// `def_value`, which defaults to nullptr.
const char* get_env_var(
    const char* var_name, const char* def_value = nullptr) {
  if (const char* found = std::getenv(var_name)) {
    return found;
  }
  return def_value;
}
31 
32 #ifndef C10_MOBILE
get_env_num_threads(const char * var_name,size_t def_value=0)33 size_t get_env_num_threads(const char* var_name, size_t def_value = 0) {
34   try {
35     if (auto* value = std::getenv(var_name)) {
36       int nthreads = std::stoi(value);
37       TORCH_CHECK(nthreads > 0);
38       return nthreads;
39     }
40   } catch (const std::exception& e) {
41     std::ostringstream oss;
42     oss << "Invalid " << var_name << " variable value, " << e.what();
43     TORCH_WARN(oss.str());
44   }
45   return def_value;
46 }
47 #endif
48 
49 } // namespace
50 
get_parallel_info()51 std::string get_parallel_info() {
52   std::ostringstream ss;
53 
54   ss << "ATen/Parallel:\n\tat::get_num_threads() : "
55      << at::get_num_threads() << '\n';
56   ss << "\tat::get_num_interop_threads() : "
57      << at::get_num_interop_threads() << '\n';
58 
59   ss << at::get_openmp_version() << '\n';
60 #ifdef _OPENMP
61   ss << "\tomp_get_max_threads() : " << omp_get_max_threads() << '\n';
62 #endif
63 
64   ss << at::get_mkl_version() << '\n';
65 #if AT_MKL_ENABLED()
66   ss << "\tmkl_get_max_threads() : " << mkl_get_max_threads() << '\n';
67 #endif
68 
69   ss << at::get_mkldnn_version() << '\n';
70 
71   ss << "std::thread::hardware_concurrency() : "
72      << std::thread::hardware_concurrency() << '\n';
73 
74   ss << "Environment variables:" << '\n';
75   ss << "\tOMP_NUM_THREADS : "
76      << get_env_var("OMP_NUM_THREADS", "[not set]") << '\n';
77   ss << "\tMKL_NUM_THREADS : "
78      << get_env_var("MKL_NUM_THREADS", "[not set]") << '\n';
79 
80   ss << "ATen parallel backend: ";
81   #if AT_PARALLEL_OPENMP
82   ss << "OpenMP";
83   #elif AT_PARALLEL_NATIVE
84   ss << "native thread pool";
85   #endif
86   #ifdef C10_MOBILE
87   ss << " [mobile]";
88   #endif
89   ss << '\n';
90 
91   #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
92   ss << "Experimental: single thread pool" << std::endl;
93   #endif
94 
95   return ss.str();
96 }
97 
intraop_default_num_threads()98 int intraop_default_num_threads() {
99 #ifdef C10_MOBILE
100   // Intraop thread pool size should be determined by mobile cpuinfo.
101   // We should hook up with the logic in caffe2/utils/threadpool if we ever need
102   // call this API for mobile.
103   TORCH_CHECK(false, "Undefined intraop_default_num_threads on mobile.");
104 #else
105   size_t nthreads = get_env_num_threads("OMP_NUM_THREADS", 0);
106   nthreads = get_env_num_threads("MKL_NUM_THREADS", nthreads);
107   if (nthreads == 0) {
108 #if defined(FBCODE_CAFFE2) && defined(__aarch64__)
109     nthreads = 1;
110 #else
111 #if defined(__aarch64__) && defined(__APPLE__)
112     // On Apple Silicon there are efficient and performance core
113     // Restrict parallel algorithms to performance cores by default
114     int32_t num_cores = -1;
115     size_t num_cores_len = sizeof(num_cores);
116     if (sysctlbyname("hw.perflevel0.physicalcpu", &num_cores, &num_cores_len, nullptr, 0) == 0) {
117       if (num_cores > 1) {
118         nthreads = num_cores;
119         return num_cores;
120       }
121     }
122 #endif
123     nthreads = TaskThreadPoolBase::defaultNumThreads();
124 #endif
125   }
126   return static_cast<int>(nthreads);
127 #endif /* !defined(C10_MOBILE) */
128 }
129 
130 } // namespace at
131