1 #include <ATen/Config.h>
2 #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE
3 #include <ATen/Parallel.h>
4 #include <ATen/PTThreadPool.h>
5 #include <ATen/ThreadLocalState.h>
6
7 #include <atomic>
8
9 namespace at {
10
11 namespace {
// Sentinel states of num_interop_threads. Both are non-positive, so they can
// never be confused with a thread count accepted by set_num_interop_threads().
constexpr int NOT_SET = -1;  // no thread count has been stored yet
constexpr int CONSUMED = -2; // get_pool() has taken the stored value

// Number of inter-op threads set by the user. The value only ever moves
// forward along one of these paths:
//   NOT_SET -> positive value -> CONSUMED
// or
//   NOT_SET -> CONSUMED
// The positive value is installed by set_num_interop_threads(); CONSUMED is
// installed by get_pool() when it initializes the thread pool.
std::atomic<int> num_interop_threads{NOT_SET};
21
22 // thread pool global instance is hidden,
23 // users should use at::launch and get/set_num_interop_threads interface
get_pool()24 TaskThreadPoolBase& get_pool() {
25 static std::shared_ptr<TaskThreadPoolBase> pool =
26 ThreadPoolRegistry()->Create(
27 "C10",
28 /* device_id */ 0,
29 /* pool_size */ num_interop_threads.exchange(CONSUMED),
30 /* create_new */ true);
31 return *pool;
32 }
33
34 // Factory function for ThreadPoolRegistry
create_c10_threadpool(int device_id,int pool_size,bool create_new)35 std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
36 int device_id,
37 int pool_size,
38 bool create_new) {
39 // For now, the only accepted device id is 0
40 TORCH_CHECK(device_id == 0);
41 // Create new thread pool
42 TORCH_CHECK(create_new);
43 return std::make_shared<PTThreadPool>(pool_size);
44 }
45
46 } // namespace
47
// Registers create_c10_threadpool with ThreadPoolRegistry under the key C10;
// get_pool() asks the registry for that same key ("C10") when it builds the
// inter-op pool.
48 C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49
set_num_interop_threads(int nthreads)50 void set_num_interop_threads(int nthreads) {
51 TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
52
53 int no_value = NOT_SET;
54 TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
55 "Error: cannot set number of interop threads after parallel work "
56 "has started or set_num_interop_threads called");
57 }
58
get_num_interop_threads()59 int get_num_interop_threads() {
60 at::internal::lazy_init_num_threads();
61 int nthreads = num_interop_threads.load();
62 if (nthreads > 0) {
63 return nthreads;
64 } else if (nthreads == NOT_SET) {
65 // return default value
66 return TaskThreadPoolBase::defaultNumThreads();
67 } else {
68 return get_pool().size();
69 }
70 }
71
72 namespace internal {
launch_no_thread_state(std::function<void ()> fn)73 void launch_no_thread_state(std::function<void()> fn) {
74 #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
75 intraop_launch(std::move(fn));
76 #else
77 get_pool().run(std::move(fn));
78 #endif
79 }
80 } // namespace internal
81
launch(std::function<void ()> func)82 void launch(std::function<void()> func) {
83 // NOLINTNEXTLINE(modernize-avoid-bind)
84 internal::launch_no_thread_state(std::bind([](
85 std::function<void()> f, ThreadLocalState thread_locals) {
86 ThreadLocalStateGuard guard(thread_locals);
87 f();
88 },
89 std::move(func),
90 ThreadLocalState()
91 ));
92 }
93
94 } // namespace at
95 #endif
96