xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Parallel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/Config.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
4*da0073e9SAndroid Build Coastguard Worker #include <functional>
5*da0073e9SAndroid Build Coastguard Worker #include <string>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace at {
8*da0073e9SAndroid Build Coastguard Worker 
// Ceiling division: returns the smallest integer q with q >= x / y (exact
// quotient), e.g. divup(10, 3) == 4 and divup(9, 3) == 3.
//
// Computed as the truncated quotient plus a one-step correction instead of
// the classic `(x + y - 1) / y`, because that form overflows (undefined
// behavior for int64_t) when x is close to INT64_MAX and rounds incorrectly
// when x or y is negative.
//
// Preconditions: y != 0, and not (x == INT64_MIN && y == -1); both are
// undefined behavior for the underlying integer division.
inline int64_t divup(int64_t x, int64_t y) {
  const int64_t quotient = x / y; // truncates toward zero
  const int64_t remainder = x % y; // has the sign of x
  // Truncation rounded down (toward -inf) exactly when there is a nonzero
  // remainder and the exact quotient is positive, i.e. the remainder and the
  // divisor have the same sign.
  const bool round_up = remainder != 0 && ((remainder > 0) == (y > 0));
  return quotient + (round_up ? 1 : 0);
}
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker // Called during new thread initialization
14*da0073e9SAndroid Build Coastguard Worker TORCH_API void init_num_threads();
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker // Sets the number of threads to be used in parallel region
17*da0073e9SAndroid Build Coastguard Worker TORCH_API void set_num_threads(int);
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker // Returns the maximum number of threads that may be used in a parallel region
20*da0073e9SAndroid Build Coastguard Worker TORCH_API int get_num_threads();
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker // Returns the current thread number (starting from 0)
23*da0073e9SAndroid Build Coastguard Worker // in the current parallel region, or 0 in the sequential region
24*da0073e9SAndroid Build Coastguard Worker TORCH_API int get_thread_num();
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker // Checks whether the code runs in parallel region
27*da0073e9SAndroid Build Coastguard Worker TORCH_API bool in_parallel_region();
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker namespace internal {
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker // Initialise num_threads lazily at first parallel call
lazy_init_num_threads()32*da0073e9SAndroid Build Coastguard Worker inline void lazy_init_num_threads() {
33*da0073e9SAndroid Build Coastguard Worker   thread_local bool init = false;
34*da0073e9SAndroid Build Coastguard Worker   if (C10_UNLIKELY(!init)) {
35*da0073e9SAndroid Build Coastguard Worker     at::init_num_threads();
36*da0073e9SAndroid Build Coastguard Worker     init = true;
37*da0073e9SAndroid Build Coastguard Worker   }
38*da0073e9SAndroid Build Coastguard Worker }
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker TORCH_API void set_thread_num(int);
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker class TORCH_API ThreadIdGuard {
43*da0073e9SAndroid Build Coastguard Worker  public:
ThreadIdGuard(int new_id)44*da0073e9SAndroid Build Coastguard Worker   ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
45*da0073e9SAndroid Build Coastguard Worker     set_thread_num(new_id);
46*da0073e9SAndroid Build Coastguard Worker   }
47*da0073e9SAndroid Build Coastguard Worker 
~ThreadIdGuard()48*da0073e9SAndroid Build Coastguard Worker   ~ThreadIdGuard() {
49*da0073e9SAndroid Build Coastguard Worker     set_thread_num(old_id_);
50*da0073e9SAndroid Build Coastguard Worker   }
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker  private:
53*da0073e9SAndroid Build Coastguard Worker   int old_id_;
54*da0073e9SAndroid Build Coastguard Worker };
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker } // namespace internal
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker /*
59*da0073e9SAndroid Build Coastguard Worker parallel_for
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker begin: index at which to start applying user function
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker end: index at which to stop applying user function
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker grain_size: number of elements per chunk. impacts the degree of parallelization
66*da0073e9SAndroid Build Coastguard Worker 
67*da0073e9SAndroid Build Coastguard Worker f: user function applied in parallel to the chunks, signature:
68*da0073e9SAndroid Build Coastguard Worker   void f(int64_t begin, int64_t end)
69*da0073e9SAndroid Build Coastguard Worker 
70*da0073e9SAndroid Build Coastguard Worker Warning: parallel_for does NOT copy thread local
71*da0073e9SAndroid Build Coastguard Worker states from the current thread to the worker threads.
72*da0073e9SAndroid Build Coastguard Worker This means for example that Tensor operations CANNOT be used in the
73*da0073e9SAndroid Build Coastguard Worker body of your function, only data pointers.
74*da0073e9SAndroid Build Coastguard Worker */
75*da0073e9SAndroid Build Coastguard Worker template <class F>
76*da0073e9SAndroid Build Coastguard Worker inline void parallel_for(
77*da0073e9SAndroid Build Coastguard Worker     const int64_t begin,
78*da0073e9SAndroid Build Coastguard Worker     const int64_t end,
79*da0073e9SAndroid Build Coastguard Worker     const int64_t grain_size,
80*da0073e9SAndroid Build Coastguard Worker     const F& f);
81*da0073e9SAndroid Build Coastguard Worker 
82*da0073e9SAndroid Build Coastguard Worker /*
83*da0073e9SAndroid Build Coastguard Worker parallel_reduce
84*da0073e9SAndroid Build Coastguard Worker 
85*da0073e9SAndroid Build Coastguard Worker begin: index at which to start applying reduction
86*da0073e9SAndroid Build Coastguard Worker 
87*da0073e9SAndroid Build Coastguard Worker end: index at which to stop applying reduction
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker grain_size: number of elements per chunk. impacts number of elements in
90*da0073e9SAndroid Build Coastguard Worker intermediate results tensor and degree of parallelization.
91*da0073e9SAndroid Build Coastguard Worker 
92*da0073e9SAndroid Build Coastguard Worker ident: identity for binary combination function sf. sf(ident, x) needs to return
93*da0073e9SAndroid Build Coastguard Worker x.
94*da0073e9SAndroid Build Coastguard Worker 
95*da0073e9SAndroid Build Coastguard Worker f: function for reduction over a chunk. f needs to be of signature scalar_t
f(int64_t partial_begin, int64_t partial_end, scalar_t identity)
97*da0073e9SAndroid Build Coastguard Worker 
98*da0073e9SAndroid Build Coastguard Worker sf: function to combine two partial results. sf needs to be of signature
99*da0073e9SAndroid Build Coastguard Worker scalar_t sf(scalar_t x, scalar_t y)
100*da0073e9SAndroid Build Coastguard Worker 
For example, you might have a tensor of 10000 entries and want to sum together
102*da0073e9SAndroid Build Coastguard Worker all the elements. Parallel_reduce with a grain_size of 2500 will then allocate
103*da0073e9SAndroid Build Coastguard Worker an intermediate result tensor with 4 elements. Then it will execute the function
104*da0073e9SAndroid Build Coastguard Worker "f" you provide and pass the beginning and end index of these chunks, so
105*da0073e9SAndroid Build Coastguard Worker 0-2499, 2500-4999, etc. and the combination identity. It will then write out
106*da0073e9SAndroid Build Coastguard Worker the result from each of these chunks into the intermediate result tensor. After
107*da0073e9SAndroid Build Coastguard Worker that it'll reduce the partial results from each chunk into a single number using
108*da0073e9SAndroid Build Coastguard Worker the combination function sf and the identity ident. For a total summation this
109*da0073e9SAndroid Build Coastguard Worker would be "+" and 0 respectively. This is similar to tbb's approach [1], where
110*da0073e9SAndroid Build Coastguard Worker you need to provide a function to accumulate a subrange, a function to combine
111*da0073e9SAndroid Build Coastguard Worker two partial results and an identity.
112*da0073e9SAndroid Build Coastguard Worker 
113*da0073e9SAndroid Build Coastguard Worker Warning: parallel_reduce does NOT copy thread local
114*da0073e9SAndroid Build Coastguard Worker states from the current thread to the worker threads.
115*da0073e9SAndroid Build Coastguard Worker This means for example that Tensor operations CANNOT be used in the
116*da0073e9SAndroid Build Coastguard Worker body of your function, only data pointers.
117*da0073e9SAndroid Build Coastguard Worker 
118*da0073e9SAndroid Build Coastguard Worker [1] https://software.intel.com/en-us/node/506154
119*da0073e9SAndroid Build Coastguard Worker */
120*da0073e9SAndroid Build Coastguard Worker template <class scalar_t, class F, class SF>
121*da0073e9SAndroid Build Coastguard Worker inline scalar_t parallel_reduce(
122*da0073e9SAndroid Build Coastguard Worker     const int64_t begin,
123*da0073e9SAndroid Build Coastguard Worker     const int64_t end,
124*da0073e9SAndroid Build Coastguard Worker     const int64_t grain_size,
125*da0073e9SAndroid Build Coastguard Worker     const scalar_t ident,
126*da0073e9SAndroid Build Coastguard Worker     const F& f,
127*da0073e9SAndroid Build Coastguard Worker     const SF& sf);
128*da0073e9SAndroid Build Coastguard Worker 
129*da0073e9SAndroid Build Coastguard Worker // Returns a detailed string describing parallelization settings
130*da0073e9SAndroid Build Coastguard Worker TORCH_API std::string get_parallel_info();
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker // Sets number of threads used for inter-op parallelism
133*da0073e9SAndroid Build Coastguard Worker TORCH_API void set_num_interop_threads(int);
134*da0073e9SAndroid Build Coastguard Worker 
135*da0073e9SAndroid Build Coastguard Worker // Returns the number of threads used for inter-op parallelism
136*da0073e9SAndroid Build Coastguard Worker TORCH_API int get_num_interop_threads();
137*da0073e9SAndroid Build Coastguard Worker 
138*da0073e9SAndroid Build Coastguard Worker // Launches inter-op parallel task
139*da0073e9SAndroid Build Coastguard Worker TORCH_API void launch(std::function<void()> func);
140*da0073e9SAndroid Build Coastguard Worker namespace internal {
141*da0073e9SAndroid Build Coastguard Worker void launch_no_thread_state(std::function<void()> fn);
142*da0073e9SAndroid Build Coastguard Worker } // namespace internal
143*da0073e9SAndroid Build Coastguard Worker 
144*da0073e9SAndroid Build Coastguard Worker // Launches intra-op parallel task
145*da0073e9SAndroid Build Coastguard Worker TORCH_API void intraop_launch(std::function<void()> func);
146*da0073e9SAndroid Build Coastguard Worker 
147*da0073e9SAndroid Build Coastguard Worker // Returns number of intra-op threads used by default
148*da0073e9SAndroid Build Coastguard Worker TORCH_API int intraop_default_num_threads();
149*da0073e9SAndroid Build Coastguard Worker 
150*da0073e9SAndroid Build Coastguard Worker } // namespace at
151*da0073e9SAndroid Build Coastguard Worker 
152*da0073e9SAndroid Build Coastguard Worker #if AT_PARALLEL_OPENMP
153*da0073e9SAndroid Build Coastguard Worker #include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
154*da0073e9SAndroid Build Coastguard Worker #elif AT_PARALLEL_NATIVE
155*da0073e9SAndroid Build Coastguard Worker #include <ATen/ParallelNative.h> // IWYU pragma: keep
156*da0073e9SAndroid Build Coastguard Worker #endif
157*da0073e9SAndroid Build Coastguard Worker 
158*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel-inl.h> // IWYU pragma: keep
159