// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ParallelThreadPoolNative.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Config.h>
2 #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE
3 #include <ATen/Parallel.h>
4 #include <ATen/PTThreadPool.h>
5 #include <ATen/ThreadLocalState.h>
6 
7 #include <atomic>
8 
9 namespace at {
10 
11 namespace {
// Sentinel states for num_interop_threads. A user-supplied count is always
// positive (set_num_interop_threads rejects anything else), so these two
// negative values can never collide with a real request.
constexpr int NOT_SET = -1;   // nothing requested yet and pool not built
constexpr int CONSUMED = -2;  // the pool has taken (consumed) the value

// Inter-op thread count requested by the user. Legal transitions:
//   NOT_SET -> <positive count> -> CONSUMED   (user set it, then pool built)
//   NOT_SET -> CONSUMED                       (pool built with no user value)
// Reaching CONSUMED means the thread pool has been initialized and the count
// can no longer be changed.
std::atomic<int> num_interop_threads{NOT_SET};
21 
22 // thread pool global instance is hidden,
23 // users should use at::launch and get/set_num_interop_threads interface
get_pool()24 TaskThreadPoolBase& get_pool() {
25   static std::shared_ptr<TaskThreadPoolBase> pool =
26       ThreadPoolRegistry()->Create(
27           "C10",
28           /* device_id */ 0,
29           /* pool_size */ num_interop_threads.exchange(CONSUMED),
30           /* create_new */ true);
31   return *pool;
32 }
33 
34 // Factory function for ThreadPoolRegistry
create_c10_threadpool(int device_id,int pool_size,bool create_new)35 std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
36     int device_id,
37     int pool_size,
38     bool create_new) {
39   // For now, the only accepted device id is 0
40   TORCH_CHECK(device_id == 0);
41   // Create new thread pool
42   TORCH_CHECK(create_new);
43   return std::make_shared<PTThreadPool>(pool_size);
44 }
45 
46 } // namespace
47 
48 C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49 
set_num_interop_threads(int nthreads)50 void set_num_interop_threads(int nthreads) {
51   TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
52 
53   int no_value = NOT_SET;
54   TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
55       "Error: cannot set number of interop threads after parallel work "
56       "has started or set_num_interop_threads called");
57 }
58 
get_num_interop_threads()59 int get_num_interop_threads() {
60   at::internal::lazy_init_num_threads();
61   int nthreads = num_interop_threads.load();
62   if (nthreads > 0) {
63     return nthreads;
64   } else if (nthreads == NOT_SET) {
65     // return default value
66     return TaskThreadPoolBase::defaultNumThreads();
67   } else {
68     return get_pool().size();
69   }
70 }
71 
72 namespace internal {
launch_no_thread_state(std::function<void ()> fn)73 void launch_no_thread_state(std::function<void()> fn) {
74 #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
75   intraop_launch(std::move(fn));
76 #else
77   get_pool().run(std::move(fn));
78 #endif
79 }
80 } // namespace internal
81 
launch(std::function<void ()> func)82 void launch(std::function<void()> func) {
83   // NOLINTNEXTLINE(modernize-avoid-bind)
84   internal::launch_no_thread_state(std::bind([](
85     std::function<void()> f, ThreadLocalState thread_locals) {
86       ThreadLocalStateGuard guard(thread_locals);
87       f();
88     },
89     std::move(func),
90     ThreadLocalState()
91   ));
92 }
93 
94 } // namespace at
95 #endif
96