1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/threadpool/threadpool.h>
10
11 #include <algorithm>
12 #include <atomic>
13 #include <memory>
14
15 #include <executorch/extension/threadpool/threadpool_guard.h>
16 #include <executorch/runtime/platform/assert.h>
17
18 #include <cpuinfo.h>
19
20 namespace executorch::extension::threadpool {
21
#if !(defined(WIN32))
namespace {
// After fork, the child process inherits the data-structures of the parent
// process' thread-pool, but since those threads don't exist, the thread-pool
// is corrupt. It's leaked in order to prevent segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
//
// Written by the pthread_atfork() child handler below and read/cleared by
// get_threadpool(), which performs no other synchronization. std::atomic makes
// those individual loads/stores well-defined when get_threadpool() is reached
// from more than one thread. Implicit conversions keep existing `flag = x` and
// `if (flag)` uses source-compatible.
std::atomic<bool> leak_corrupted_threadpool{false};

// Child-side pthread_atfork() handler: marks the inherited pool as unusable so
// the next get_threadpool() call replaces it.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
#endif
36
ThreadPool(size_t thread_count)37 ThreadPool::ThreadPool(size_t thread_count)
38 : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
39
get_thread_count() const40 size_t ThreadPool::get_thread_count() const {
41 std::lock_guard<std::mutex> lock{mutex_};
42
43 ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
44 return pthreadpool_get_threads_count(threadpool_.get());
45 }
46
_unsafe_reset_threadpool(uint32_t new_thread_count)47 bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
48 // No need to do anything if the count is same or 0
49 if (new_thread_count == get_thread_count() || new_thread_count == 0) {
50 return true;
51 }
52
53 std::lock_guard<std::mutex> lock{mutex_};
54
55 threadpool_.reset(pthreadpool_create(new_thread_count));
56 return true;
57 }
58
run(const std::function<void (size_t)> & fn,const size_t range)59 void ThreadPool::run(
60 const std::function<void(size_t)>& fn,
61 const size_t range) {
62 // Run on same thread if NoThreadPoolGuard guard is enabled
63 if (NoThreadPoolGuard::is_enabled()) {
64 for (size_t i = 0; i < range; ++i) {
65 fn(i);
66 }
67 return;
68 }
69
70 std::lock_guard<std::mutex> lock{mutex_};
71
72 ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
73 ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
74
75 struct Context final {
76 const std::function<void(size_t)>& fn;
77 } context{
78 fn,
79 };
80
81 pthreadpool_parallelize_1d(
82 threadpool_.get(),
83 // Note: pthreadpool_parallelize_1d() is a blocking function. The
84 // function pointer to this lambda passed on to
85 // pthreadpool_parallelize_1d() cannot go out of scope until
86 // pthreadpool_parallelize_1d() returns.
87 [](void* const context, const size_t item) {
88 NoThreadPoolGuard guard;
89 reinterpret_cast<Context*>(context)->fn(item);
90 },
91 &context,
92 range,
93 0u);
94 }
95
96 // get_threadpool is not thread safe due to leak_corrupted_threadpool
97 // Make this part threadsafe: TODO(kimishpatel)
get_threadpool()98 ThreadPool* get_threadpool() {
99 ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
100 int num_threads = cpuinfo_get_processors_count();
101 /*
102 * For llvm-tsan, holding limit for the number of locks for a single thread
103 * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
104 * case is the number of threads in a pool. So we want to limit the threadpool
105 * size to 64 when running with tsan. However, sometimes it is tricky to
106 * detect if we are running under tsan, for now capping the default
107 * threadcount to the tsan limit unconditionally.
108 */
109 constexpr int tsan_thread_limit = 63;
110 num_threads = std::min(num_threads, tsan_thread_limit);
111 static auto threadpool = std::make_unique<ThreadPool>(num_threads);
112
113 // Inheriting from old threadpool to get around segfault issue
114 // commented above at child_atfork
115 #if !(defined(WIN32))
116 // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
117 static std::once_flag flag;
118 // @lint-ignore CLANGTIDY facebook-hte-std::call_once
119 std::call_once(
120 flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
121 if ET_UNLIKELY (leak_corrupted_threadpool) {
122 leak_corrupted_threadpool = false;
123 if (auto leaked = threadpool.release()) {
124 auto t = leaked->get_thread_count();
125 threadpool = std::make_unique<ThreadPool>(t);
126 }
127 }
128 #endif
129 return threadpool.get();
130 }
131
get_pthreadpool()132 pthreadpool_t get_pthreadpool() {
133 if (NoThreadPoolGuard::is_enabled()) {
134 return nullptr;
135 }
136 ThreadPool* const threadpool = get_threadpool();
137 ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!");
138 return threadpool->threadpool_.get();
139 }
140
141 } // namespace executorch::extension::threadpool
142