xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ParallelThreadPoolNative.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
2*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/PTThreadPool.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalState.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <atomic>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace at {
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker namespace {
// Sentinel states for num_interop_threads. Both are negative, so they cannot
// collide with a user-supplied thread count (set_num_interop_threads only
// accepts values > 0).
constexpr int NOT_SET = -1;
constexpr int CONSUMED = -2;

// Number of inter-op threads set by the user. Allowed state transitions:
//   NOT_SET -> positive value -> CONSUMED
// or
//   NOT_SET -> CONSUMED
// CONSUMED means the value has been handed over to the thread pool, i.e. the
// thread pool is initialized.
std::atomic<int> num_interop_threads{NOT_SET};
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker // thread pool global instance is hidden,
23*da0073e9SAndroid Build Coastguard Worker // users should use at::launch and get/set_num_interop_threads interface
get_pool()24*da0073e9SAndroid Build Coastguard Worker TaskThreadPoolBase& get_pool() {
25*da0073e9SAndroid Build Coastguard Worker   static std::shared_ptr<TaskThreadPoolBase> pool =
26*da0073e9SAndroid Build Coastguard Worker       ThreadPoolRegistry()->Create(
27*da0073e9SAndroid Build Coastguard Worker           "C10",
28*da0073e9SAndroid Build Coastguard Worker           /* device_id */ 0,
29*da0073e9SAndroid Build Coastguard Worker           /* pool_size */ num_interop_threads.exchange(CONSUMED),
30*da0073e9SAndroid Build Coastguard Worker           /* create_new */ true);
31*da0073e9SAndroid Build Coastguard Worker   return *pool;
32*da0073e9SAndroid Build Coastguard Worker }
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker // Factory function for ThreadPoolRegistry
create_c10_threadpool(int device_id,int pool_size,bool create_new)35*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
36*da0073e9SAndroid Build Coastguard Worker     int device_id,
37*da0073e9SAndroid Build Coastguard Worker     int pool_size,
38*da0073e9SAndroid Build Coastguard Worker     bool create_new) {
39*da0073e9SAndroid Build Coastguard Worker   // For now, the only accepted device id is 0
40*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(device_id == 0);
41*da0073e9SAndroid Build Coastguard Worker   // Create new thread pool
42*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(create_new);
43*da0073e9SAndroid Build Coastguard Worker   return std::make_shared<PTThreadPool>(pool_size);
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker } // namespace
47*da0073e9SAndroid Build Coastguard Worker 
48*da0073e9SAndroid Build Coastguard Worker C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49*da0073e9SAndroid Build Coastguard Worker 
set_num_interop_threads(int nthreads)50*da0073e9SAndroid Build Coastguard Worker void set_num_interop_threads(int nthreads) {
51*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker   int no_value = NOT_SET;
54*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
55*da0073e9SAndroid Build Coastguard Worker       "Error: cannot set number of interop threads after parallel work "
56*da0073e9SAndroid Build Coastguard Worker       "has started or set_num_interop_threads called");
57*da0073e9SAndroid Build Coastguard Worker }
58*da0073e9SAndroid Build Coastguard Worker 
get_num_interop_threads()59*da0073e9SAndroid Build Coastguard Worker int get_num_interop_threads() {
60*da0073e9SAndroid Build Coastguard Worker   at::internal::lazy_init_num_threads();
61*da0073e9SAndroid Build Coastguard Worker   int nthreads = num_interop_threads.load();
62*da0073e9SAndroid Build Coastguard Worker   if (nthreads > 0) {
63*da0073e9SAndroid Build Coastguard Worker     return nthreads;
64*da0073e9SAndroid Build Coastguard Worker   } else if (nthreads == NOT_SET) {
65*da0073e9SAndroid Build Coastguard Worker     // return default value
66*da0073e9SAndroid Build Coastguard Worker     return TaskThreadPoolBase::defaultNumThreads();
67*da0073e9SAndroid Build Coastguard Worker   } else {
68*da0073e9SAndroid Build Coastguard Worker     return get_pool().size();
69*da0073e9SAndroid Build Coastguard Worker   }
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker 
72*da0073e9SAndroid Build Coastguard Worker namespace internal {
launch_no_thread_state(std::function<void ()> fn)73*da0073e9SAndroid Build Coastguard Worker void launch_no_thread_state(std::function<void()> fn) {
74*da0073e9SAndroid Build Coastguard Worker #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
75*da0073e9SAndroid Build Coastguard Worker   intraop_launch(std::move(fn));
76*da0073e9SAndroid Build Coastguard Worker #else
77*da0073e9SAndroid Build Coastguard Worker   get_pool().run(std::move(fn));
78*da0073e9SAndroid Build Coastguard Worker #endif
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker } // namespace internal
81*da0073e9SAndroid Build Coastguard Worker 
launch(std::function<void ()> func)82*da0073e9SAndroid Build Coastguard Worker void launch(std::function<void()> func) {
83*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(modernize-avoid-bind)
84*da0073e9SAndroid Build Coastguard Worker   internal::launch_no_thread_state(std::bind([](
85*da0073e9SAndroid Build Coastguard Worker     std::function<void()> f, ThreadLocalState thread_locals) {
86*da0073e9SAndroid Build Coastguard Worker       ThreadLocalStateGuard guard(thread_locals);
87*da0073e9SAndroid Build Coastguard Worker       f();
88*da0073e9SAndroid Build Coastguard Worker     },
89*da0073e9SAndroid Build Coastguard Worker     std::move(func),
90*da0073e9SAndroid Build Coastguard Worker     ThreadLocalState()
91*da0073e9SAndroid Build Coastguard Worker   ));
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker 
94*da0073e9SAndroid Build Coastguard Worker } // namespace at
95*da0073e9SAndroid Build Coastguard Worker #endif
96