/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/threadpool.h"

#define EIGEN_USE_THREADS

#include "absl/types/optional.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {
namespace thread {

struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};
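
// How this plugs into Eigen (an explanatory sketch, not part of the original
// file): Eigen::ThreadPoolTempl is parameterized on this environment. It
// calls CreateThread() once per worker at pool construction, wraps each
// closure handed to Schedule() via CreateTask() on the scheduling thread,
// and later invokes ExecuteTask() on a worker. The TaskImpl indirection is
// what lets the pool capture the scheduling thread's Context and a trace id
// at schedule time and restore/report them at execution time.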

ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads,
                       bool low_latency_hint, Eigen::Allocator* allocator) {
  CHECK_GE(num_threads, 1);
  eigen_threadpool_.reset(new Eigen::ThreadPoolTempl<EigenEnvironment>(
      num_threads, low_latency_hint,
      EigenEnvironment(env, thread_options, "tf_" + name)));
  underlying_threadpool_ = eigen_threadpool_.get();
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_,
                                                       num_threads, allocator));
}

ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) {
  underlying_threadpool_ = user_threadpool;
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(
      underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr));
}

ThreadPool::~ThreadPool() {}

void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  underlying_threadpool_->Schedule(std::move(fn));
}

int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(
    const int64_t total, const int64_t block_size) {
  if (block_size <= 0 || total <= 1 || total <= block_size ||
      NumThreads() == 1) {
    return 1;
  }
  return (total + block_size - 1) / block_size;
}
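
// Illustrative example (not part of the original file): with total = 10 and
// block_size = 3 on a pool with more than one thread, this returns
// ceil(10 / 3) = 4 shards; if block_size >= total, block_size <= 0, or the
// pool has a single thread, it returns 1 and the caller runs the whole range
// inline.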

int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64_t block_size, const int64_t total) {
  return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
}

void ThreadPool::ParallelFor(int64_t total,
                             const SchedulingParams& scheduling_params,
                             const std::function<void(int64_t, int64_t)>& fn) {
  switch (scheduling_params.strategy()) {
    case SchedulingStrategy::kAdaptive: {
      if (scheduling_params.cost_per_unit().has_value()) {
        ParallelFor(total, *scheduling_params.cost_per_unit(), fn);
      }
      break;
    }
    case SchedulingStrategy::kFixedBlockSize: {
      if (scheduling_params.block_size().has_value()) {
        ParallelForFixedBlockSizeScheduling(
            total, *scheduling_params.block_size(), fn);
      }
      break;
    }
  }
}
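
// Note that if the strategy's required parameter is absent (no cost_per_unit
// for kAdaptive, no block_size for kFixedBlockSize), the call above is a
// no-op. A hypothetical caller (illustrative, not part of this file):
//
//   ThreadPool pool(Env::Default(), "demo", /*num_threads=*/4);
//   pool.ParallelFor(
//       1000,
//       ThreadPool::SchedulingParams(
//           ThreadPool::SchedulingStrategy::kFixedBlockSize,
//           /*cost_per_unit=*/absl::nullopt, /*block_size=*/100),
//       [](int64_t first, int64_t last) { /* process [first, last) */ });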

void ThreadPool::TransformRangeConcurrently(
    const int64_t block_size, const int64_t total,
    const std::function<void(int64_t, int64_t)>& fn) {
  ParallelFor(total,
              SchedulingParams(SchedulingStrategy::kFixedBlockSize,
                               absl::nullopt /* cost_per_unit */, block_size),
              fn);
}

// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
void ThreadPool::ParallelForFixedBlockSizeScheduling(
    const int64_t total, const int64_t block_size,
    const std::function<void(int64_t, int64_t)>& fn) {
  const int num_shards_used =
      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  BlockingCounter counter(num_shards_used);
  std::function<void(int64_t, int64_t)> handle_range =
      [=, &handle_range, &counter, &fn](int64_t first, int64_t last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64_t mid = first + ((last - first) / 2 + block_size - 1) /
                                          block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}
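
// Illustrative trace (not part of the original file): with total = 8 and
// block_size = 2, handle_range(0, 8) computes mid = 4, schedules
// handle_range(4, 8), and keeps [0, 4); each half splits once more, yielding
// the four shards [0, 2), [2, 4), [4, 6), [6, 8) as a binary tree of tasks.
// counter.Wait() returns once all four shards have called DecrementCount().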

void ThreadPool::ParallelFor(int64_t total, int64_t cost_per_unit,
                             const std::function<void(int64_t, int64_t)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);
  threadpool_device_->parallelFor(
      total, Eigen::TensorOpCost(0, 0, cost_per_unit),
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
}
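
// Hypothetical caller (illustrative, not part of this file): cost_per_unit
// is an estimate of the cost of one unit of work, which Eigen uses to decide
// how finely to shard; see the comments in threadpool.h for the exact units.
//
//   pool.ParallelFor(num_elems, /*cost_per_unit=*/10,
//                    [&](int64_t first, int64_t last) {
//                      for (int64_t i = first; i < last; ++i) {
//                        out[i] = 2.0f * in[i];  // elementwise work
//                      }
//                    });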

void ThreadPool::ParallelForWithWorkerId(
    int64_t total, int64_t cost_per_unit,
    const std::function<void(int64_t, int64_t, int)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);

  threadpool_device_->parallelFor(total,
                                  Eigen::TensorOpCost(0, 0, cost_per_unit),
                                  [this, &fn](int64_t start, int64_t limit) {
                                    // ParallelFor may use the current thread to
                                    // do some work synchronously. When calling
                                    // CurrentThreadId() from outside of the
                                    // thread pool, we get -1, so we can shift
                                    // every id up by 1.
                                    int id = CurrentThreadId() + 1;
                                    fn(start, limit, id);
                                  });
}
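
// Hypothetical caller (illustrative, not part of this file): the shifted id
// lies in [0, NumThreads()], with 0 corresponding to work run inline on a
// thread outside the pool, so it can index a per-worker accumulator without
// locking:
//
//   std::vector<double> partial(pool.NumThreads() + 1, 0.0);
//   pool.ParallelForWithWorkerId(
//       n, /*cost_per_unit=*/100,
//       [&](int64_t first, int64_t last, int id) {
//         for (int64_t i = first; i < last; ++i) partial[id] += data[i];
//       });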

void ThreadPool::ParallelForWithWorkerId(
    int64_t total, const SchedulingParams& scheduling_params,
    const std::function<void(int64_t, int64_t, int)>& fn) {
  ParallelFor(total, scheduling_params,
              [this, &fn](int64_t start, int64_t limit) {
                // We may use the current thread to do some work synchronously.
                // When calling CurrentThreadId() from outside of the thread
                // pool, we get -1, so we can shift every id up by 1.
                int id = CurrentThreadId() + 1;
                fn(start, limit, id);
              });
}

int ThreadPool::NumThreads() const {
  return underlying_threadpool_->NumThreads();
}

int ThreadPool::CurrentThreadId() const {
  return underlying_threadpool_->CurrentThreadId();
}

void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit);
}

void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  // ThreadPool::SetStealPartitions is only called in the constructor of
  // RunHandlerPool::Impl, which currently instantiates ThreadPool using a
  // constructor that does not take user_threadpool. Thus we assume
  // eigen_threadpool_ is not null here.
  DCHECK(eigen_threadpool_ != nullptr);
  eigen_threadpool_->SetStealPartitions(partitions);
}

Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
  DCHECK(underlying_threadpool_ != nullptr);
  return underlying_threadpool_;
}
}  // namespace thread
}  // namespace tensorflow