1*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
2*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/PTThreadPool.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalState.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #include <atomic>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker namespace at {
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker namespace {
// Sentinel states stored in num_interop_threads. A real, user-requested thread
// count is always positive (set_num_interop_threads rejects nthreads <= 0), so
// these negative values can never collide with one.
constexpr int NOT_SET = -1;   // no count requested and the pool not yet built
constexpr int CONSUMED = -2;  // get_pool() has taken the value; it is now frozen

// Inter-op thread count requested by the user. Possible transitions:
//   NOT_SET -> <positive count> -> CONSUMED   (count requested, then pool built)
//   NOT_SET -> CONSUMED                       (pool built without a request)
// Once CONSUMED, the count can no longer be changed.
std::atomic<int> num_interop_threads{NOT_SET};
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker // thread pool global instance is hidden,
23*da0073e9SAndroid Build Coastguard Worker // users should use at::launch and get/set_num_interop_threads interface
get_pool()24*da0073e9SAndroid Build Coastguard Worker TaskThreadPoolBase& get_pool() {
25*da0073e9SAndroid Build Coastguard Worker static std::shared_ptr<TaskThreadPoolBase> pool =
26*da0073e9SAndroid Build Coastguard Worker ThreadPoolRegistry()->Create(
27*da0073e9SAndroid Build Coastguard Worker "C10",
28*da0073e9SAndroid Build Coastguard Worker /* device_id */ 0,
29*da0073e9SAndroid Build Coastguard Worker /* pool_size */ num_interop_threads.exchange(CONSUMED),
30*da0073e9SAndroid Build Coastguard Worker /* create_new */ true);
31*da0073e9SAndroid Build Coastguard Worker return *pool;
32*da0073e9SAndroid Build Coastguard Worker }
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker // Factory function for ThreadPoolRegistry
create_c10_threadpool(int device_id,int pool_size,bool create_new)35*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
36*da0073e9SAndroid Build Coastguard Worker int device_id,
37*da0073e9SAndroid Build Coastguard Worker int pool_size,
38*da0073e9SAndroid Build Coastguard Worker bool create_new) {
39*da0073e9SAndroid Build Coastguard Worker // For now, the only accepted device id is 0
40*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(device_id == 0);
41*da0073e9SAndroid Build Coastguard Worker // Create new thread pool
42*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(create_new);
43*da0073e9SAndroid Build Coastguard Worker return std::make_shared<PTThreadPool>(pool_size);
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker } // namespace
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49*da0073e9SAndroid Build Coastguard Worker
set_num_interop_threads(int nthreads)50*da0073e9SAndroid Build Coastguard Worker void set_num_interop_threads(int nthreads) {
51*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker int no_value = NOT_SET;
54*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
55*da0073e9SAndroid Build Coastguard Worker "Error: cannot set number of interop threads after parallel work "
56*da0073e9SAndroid Build Coastguard Worker "has started or set_num_interop_threads called");
57*da0073e9SAndroid Build Coastguard Worker }
58*da0073e9SAndroid Build Coastguard Worker
get_num_interop_threads()59*da0073e9SAndroid Build Coastguard Worker int get_num_interop_threads() {
60*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads();
61*da0073e9SAndroid Build Coastguard Worker int nthreads = num_interop_threads.load();
62*da0073e9SAndroid Build Coastguard Worker if (nthreads > 0) {
63*da0073e9SAndroid Build Coastguard Worker return nthreads;
64*da0073e9SAndroid Build Coastguard Worker } else if (nthreads == NOT_SET) {
65*da0073e9SAndroid Build Coastguard Worker // return default value
66*da0073e9SAndroid Build Coastguard Worker return TaskThreadPoolBase::defaultNumThreads();
67*da0073e9SAndroid Build Coastguard Worker } else {
68*da0073e9SAndroid Build Coastguard Worker return get_pool().size();
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker namespace internal {
launch_no_thread_state(std::function<void ()> fn)73*da0073e9SAndroid Build Coastguard Worker void launch_no_thread_state(std::function<void()> fn) {
74*da0073e9SAndroid Build Coastguard Worker #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
75*da0073e9SAndroid Build Coastguard Worker intraop_launch(std::move(fn));
76*da0073e9SAndroid Build Coastguard Worker #else
77*da0073e9SAndroid Build Coastguard Worker get_pool().run(std::move(fn));
78*da0073e9SAndroid Build Coastguard Worker #endif
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker } // namespace internal
81*da0073e9SAndroid Build Coastguard Worker
launch(std::function<void ()> func)82*da0073e9SAndroid Build Coastguard Worker void launch(std::function<void()> func) {
83*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(modernize-avoid-bind)
84*da0073e9SAndroid Build Coastguard Worker internal::launch_no_thread_state(std::bind([](
85*da0073e9SAndroid Build Coastguard Worker std::function<void()> f, ThreadLocalState thread_locals) {
86*da0073e9SAndroid Build Coastguard Worker ThreadLocalStateGuard guard(thread_locals);
87*da0073e9SAndroid Build Coastguard Worker f();
88*da0073e9SAndroid Build Coastguard Worker },
89*da0073e9SAndroid Build Coastguard Worker std::move(func),
90*da0073e9SAndroid Build Coastguard Worker ThreadLocalState()
91*da0073e9SAndroid Build Coastguard Worker ));
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker } // namespace at
95*da0073e9SAndroid Build Coastguard Worker #endif
96