xref: /aosp_15_r20/external/executorch/extension/threadpool/threadpool.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/threadpool/threadpool.h>
10 
11 #include <algorithm>
12 #include <atomic>
13 #include <memory>
14 
15 #include <executorch/extension/threadpool/threadpool_guard.h>
16 #include <executorch/runtime/platform/assert.h>
17 
18 #include <cpuinfo.h>
19 
20 namespace executorch::extension::threadpool {
21 
22 #if !(defined(WIN32))
namespace {

// Fork bookkeeping for the process-wide thread pool.
//
// A forked child receives a copy of the parent's thread-pool data structures,
// but none of the worker threads those structures describe, so the copied
// pool is corrupt. Destroying it could segfault; instead the flag below tells
// get_threadpool() to leak the stale pool and build a new one.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool{false};

// Child-side fork handler (installed through pthread_atfork() in
// get_threadpool()): records that the inherited pool must not be reused.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
35 #endif
36 
ThreadPool(size_t thread_count)37 ThreadPool::ThreadPool(size_t thread_count)
38     : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
39 
get_thread_count() const40 size_t ThreadPool::get_thread_count() const {
41   std::lock_guard<std::mutex> lock{mutex_};
42 
43   ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
44   return pthreadpool_get_threads_count(threadpool_.get());
45 }
46 
_unsafe_reset_threadpool(uint32_t new_thread_count)47 bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
48   // No need to do anything if the count is same or 0
49   if (new_thread_count == get_thread_count() || new_thread_count == 0) {
50     return true;
51   }
52 
53   std::lock_guard<std::mutex> lock{mutex_};
54 
55   threadpool_.reset(pthreadpool_create(new_thread_count));
56   return true;
57 }
58 
run(const std::function<void (size_t)> & fn,const size_t range)59 void ThreadPool::run(
60     const std::function<void(size_t)>& fn,
61     const size_t range) {
62   // Run on same thread if NoThreadPoolGuard guard is enabled
63   if (NoThreadPoolGuard::is_enabled()) {
64     for (size_t i = 0; i < range; ++i) {
65       fn(i);
66     }
67     return;
68   }
69 
70   std::lock_guard<std::mutex> lock{mutex_};
71 
72   ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
73   ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
74 
75   struct Context final {
76     const std::function<void(size_t)>& fn;
77   } context{
78       fn,
79   };
80 
81   pthreadpool_parallelize_1d(
82       threadpool_.get(),
83       // Note: pthreadpool_parallelize_1d() is a blocking function.  The
84       // function pointer to this lambda passed on to
85       // pthreadpool_parallelize_1d() cannot go out of scope until
86       // pthreadpool_parallelize_1d() returns.
87       [](void* const context, const size_t item) {
88         NoThreadPoolGuard guard;
89         reinterpret_cast<Context*>(context)->fn(item);
90       },
91       &context,
92       range,
93       0u);
94 }
95 
96 // get_threadpool is not thread safe due to leak_corrupted_threadpool
97 // Make this part threadsafe: TODO(kimishpatel)
get_threadpool()98 ThreadPool* get_threadpool() {
99   ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
100   int num_threads = cpuinfo_get_processors_count();
101   /*
102    * For llvm-tsan, holding limit for the number of locks for a single thread
103    * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
104    * case is the number of threads in a pool. So we want to limit the threadpool
105    * size to 64 when running with tsan. However, sometimes it is tricky to
106    * detect if we are running under tsan, for now capping the default
107    * threadcount to the tsan limit unconditionally.
108    */
109   constexpr int tsan_thread_limit = 63;
110   num_threads = std::min(num_threads, tsan_thread_limit);
111   static auto threadpool = std::make_unique<ThreadPool>(num_threads);
112 
113 // Inheriting from old threadpool to get around segfault issue
114 // commented above at child_atfork
115 #if !(defined(WIN32))
116   // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
117   static std::once_flag flag;
118   // @lint-ignore CLANGTIDY facebook-hte-std::call_once
119   std::call_once(
120       flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
121   if ET_UNLIKELY (leak_corrupted_threadpool) {
122     leak_corrupted_threadpool = false;
123     if (auto leaked = threadpool.release()) {
124       auto t = leaked->get_thread_count();
125       threadpool = std::make_unique<ThreadPool>(t);
126     }
127   }
128 #endif
129   return threadpool.get();
130 }
131 
get_pthreadpool()132 pthreadpool_t get_pthreadpool() {
133   if (NoThreadPoolGuard::is_enabled()) {
134     return nullptr;
135   }
136   ThreadPool* const threadpool = get_threadpool();
137   ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!");
138   return threadpool->threadpool_.get();
139 }
140 
141 } // namespace executorch::extension::threadpool
142