xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H
11 #define EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H
12 
13 #ifdef EIGEN_AVOID_THREAD_LOCAL
14 
15 #ifdef EIGEN_THREAD_LOCAL
16 #undef EIGEN_THREAD_LOCAL
17 #endif
18 
19 #else
20 
21 #if EIGEN_MAX_CPP_VER >= 11 &&                         \
22     ((EIGEN_COMP_GNUC && EIGEN_GNUC_AT_LEAST(4, 8)) || \
23      __has_feature(cxx_thread_local)                || \
24      (EIGEN_COMP_MSVC >= 1900) )
25 #define EIGEN_THREAD_LOCAL static thread_local
26 #endif
27 
28 // Disable TLS for Apple and Android builds with older toolchains.
29 #if defined(__APPLE__)
30 // Included for TARGET_OS_IPHONE, __IPHONE_OS_VERSION_MIN_REQUIRED,
31 // __IPHONE_8_0.
32 #include <Availability.h>
33 #include <TargetConditionals.h>
34 #endif
35 // Checks whether C++11's `thread_local` storage duration specifier is
36 // supported.
37 #if defined(__apple_build_version__) &&     \
38     ((__apple_build_version__ < 8000042) || \
39      (TARGET_OS_IPHONE && __IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_9_0))
40 // Notes: Xcode's clang did not support `thread_local` until version
41 // 8, and even then not for all iOS < 9.0.
42 #undef EIGEN_THREAD_LOCAL
43 
44 #elif defined(__ANDROID__) && EIGEN_COMP_CLANG
45 // There are platforms for which TLS should not be used even though the compiler
46 // makes it seem like it's supported (Android NDK < r12b for example).
47 // This is primarily because of linker problems and toolchain misconfiguration:
48 // TLS isn't supported until NDK r12b per
49 // https://developer.android.com/ndk/downloads/revision_history.html
50 // Since NDK r16, `__NDK_MAJOR__` and `__NDK_MINOR__` are defined in
51 // <android/ndk-version.h>. For NDK < r16, users should define these macros,
52 // e.g. `-D__NDK_MAJOR__=11 -D__NKD_MINOR__=0` for NDK r11.
53 #if __has_include(<android/ndk-version.h>)
54 #include <android/ndk-version.h>
55 #endif  // __has_include(<android/ndk-version.h>)
56 #if defined(__ANDROID__) && defined(__clang__) && defined(__NDK_MAJOR__) && \
57     defined(__NDK_MINOR__) &&                                               \
58     ((__NDK_MAJOR__ < 12) || ((__NDK_MAJOR__ == 12) && (__NDK_MINOR__ < 1)))
59 #undef EIGEN_THREAD_LOCAL
60 #endif
61 #endif  // defined(__ANDROID__) && defined(__clang__)
62 
63 #endif  // EIGEN_AVOID_THREAD_LOCAL
64 
65 namespace Eigen {
66 
namespace internal {

// Default `Initialize` policy for Eigen::ThreadLocal: accepts the freshly
// default-constructed per-thread value and leaves it untouched.
template <class T>
struct ThreadLocalNoOpInitialize {
  void operator()(T& /*value*/) const {
    // Intentionally empty.
  }
};

// Default `Release` policy for Eigen::ThreadLocal: accepts a per-thread value
// and performs no cleanup on it.
template <class T>
struct ThreadLocalNoOpRelease {
  void operator()(T& /*value*/) const {
    // Intentionally empty.
  }
};

}  // namespace internal
79 
80 // Thread local container for elements of type T, that does not use thread local
81 // storage. As long as the number of unique threads accessing this storage
82 // is smaller than `capacity_`, it is lock-free and wait-free. Otherwise it will
83 // use a mutex for synchronization.
84 //
85 // Type `T` has to be default constructible, and by default each thread will get
86 // a default constructed value. It is possible to specify custom `initialize`
87 // callable, that will be called lazily from each thread accessing this object,
88 // and will be passed a default initialized object of type `T`. Also it's
89 // possible to pass a custom `release` callable, that will be invoked before
90 // calling ~T().
91 //
92 // Example:
93 //
94 //   struct Counter {
95 //     int value = 0;
96 //   }
97 //
98 //   Eigen::ThreadLocal<Counter> counter(10);
99 //
100 //   // Each thread will have access to it's own counter object.
101 //   Counter& cnt = counter.local();
102 //   cnt++;
103 //
104 // WARNING: Eigen::ThreadLocal uses the OS-specific value returned by
105 // std::this_thread::get_id() to identify threads. This value is not guaranteed
106 // to be unique except for the life of the thread. A newly created thread may
107 // get an OS-specific ID equal to that of an already destroyed thread.
108 //
109 // Somewhat similar to TBB thread local storage, with similar restrictions:
110 // https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html
111 //
112 template <typename T,
113           typename Initialize = internal::ThreadLocalNoOpInitialize<T>,
114           typename Release = internal::ThreadLocalNoOpRelease<T>>
115 class ThreadLocal {
116   // We preallocate default constructed elements in MaxSizedVector.
117   static_assert(std::is_default_constructible<T>::value,
118                 "ThreadLocal data type must be default constructible");
119 
120  public:
ThreadLocal(int capacity)121   explicit ThreadLocal(int capacity)
122       : ThreadLocal(capacity, internal::ThreadLocalNoOpInitialize<T>(),
123                     internal::ThreadLocalNoOpRelease<T>()) {}
124 
ThreadLocal(int capacity,Initialize initialize)125   ThreadLocal(int capacity, Initialize initialize)
126       : ThreadLocal(capacity, std::move(initialize),
127                     internal::ThreadLocalNoOpRelease<T>()) {}
128 
ThreadLocal(int capacity,Initialize initialize,Release release)129   ThreadLocal(int capacity, Initialize initialize, Release release)
130       : initialize_(std::move(initialize)),
131         release_(std::move(release)),
132         capacity_(capacity),
133         data_(capacity_),
134         ptr_(capacity_),
135         filled_records_(0) {
136     eigen_assert(capacity_ >= 0);
137     data_.resize(capacity_);
138     for (int i = 0; i < capacity_; ++i) {
139       ptr_.emplace_back(nullptr);
140     }
141   }
142 
local()143   T& local() {
144     std::thread::id this_thread = std::this_thread::get_id();
145     if (capacity_ == 0) return SpilledLocal(this_thread);
146 
147     std::size_t h = std::hash<std::thread::id>()(this_thread);
148     const int start_idx = h % capacity_;
149 
150     // NOTE: From the definition of `std::this_thread::get_id()` it is
151     // guaranteed that we never can have concurrent insertions with the same key
152     // to our hash-map like data structure. If we didn't find an element during
153     // the initial traversal, it's guaranteed that no one else could have
154     // inserted it while we are in this function. This allows to massively
155     // simplify out lock-free insert-only hash map.
156 
157     // Check if we already have an element for `this_thread`.
158     int idx = start_idx;
159     while (ptr_[idx].load() != nullptr) {
160       ThreadIdAndValue& record = *(ptr_[idx].load());
161       if (record.thread_id == this_thread) return record.value;
162 
163       idx += 1;
164       if (idx >= capacity_) idx -= capacity_;
165       if (idx == start_idx) break;
166     }
167 
168     // If we are here, it means that we found an insertion point in lookup
169     // table at `idx`, or we did a full traversal and table is full.
170 
171     // If lock-free storage is full, fallback on mutex.
172     if (filled_records_.load() >= capacity_) return SpilledLocal(this_thread);
173 
174     // We double check that we still have space to insert an element into a lock
175     // free storage. If old value in `filled_records_` is larger than the
176     // records capacity, it means that some other thread added an element while
177     // we were traversing lookup table.
178     int insertion_index =
179         filled_records_.fetch_add(1, std::memory_order_relaxed);
180     if (insertion_index >= capacity_) return SpilledLocal(this_thread);
181 
182     // At this point it's guaranteed that we can access to
183     // data_[insertion_index_] without a data race.
184     data_[insertion_index].thread_id = this_thread;
185     initialize_(data_[insertion_index].value);
186 
187     // That's the pointer we'll put into the lookup table.
188     ThreadIdAndValue* inserted = &data_[insertion_index];
189 
190     // We'll use nullptr pointer to ThreadIdAndValue in a compare-and-swap loop.
191     ThreadIdAndValue* empty = nullptr;
192 
193     // Now we have to find an insertion point into the lookup table. We start
194     // from the `idx` that was identified as an insertion point above, it's
195     // guaranteed that we will have an empty record somewhere in a lookup table
196     // (because we created a record in the `data_`).
197     const int insertion_idx = idx;
198 
199     do {
200       // Always start search from the original insertion candidate.
201       idx = insertion_idx;
202       while (ptr_[idx].load() != nullptr) {
203         idx += 1;
204         if (idx >= capacity_) idx -= capacity_;
205         // If we did a full loop, it means that we don't have any free entries
206         // in the lookup table, and this means that something is terribly wrong.
207         eigen_assert(idx != insertion_idx);
208       }
209       // Atomic CAS of the pointer guarantees that any other thread, that will
210       // follow this pointer will see all the mutations in the `data_`.
211     } while (!ptr_[idx].compare_exchange_weak(empty, inserted));
212 
213     return inserted->value;
214   }
215 
216   // WARN: It's not thread safe to call it concurrently with `local()`.
ForEach(std::function<void (std::thread::id,T &)> f)217   void ForEach(std::function<void(std::thread::id, T&)> f) {
218     // Reading directly from `data_` is unsafe, because only CAS to the
219     // record in `ptr_` makes all changes visible to other threads.
220     for (auto& ptr : ptr_) {
221       ThreadIdAndValue* record = ptr.load();
222       if (record == nullptr) continue;
223       f(record->thread_id, record->value);
224     }
225 
226     // We did not spill into the map based storage.
227     if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;
228 
229     // Adds a happens before edge from the last call to SpilledLocal().
230     std::unique_lock<std::mutex> lock(mu_);
231     for (auto& kv : per_thread_map_) {
232       f(kv.first, kv.second);
233     }
234   }
235 
236   // WARN: It's not thread safe to call it concurrently with `local()`.
~ThreadLocal()237   ~ThreadLocal() {
238     // Reading directly from `data_` is unsafe, because only CAS to the record
239     // in `ptr_` makes all changes visible to other threads.
240     for (auto& ptr : ptr_) {
241       ThreadIdAndValue* record = ptr.load();
242       if (record == nullptr) continue;
243       release_(record->value);
244     }
245 
246     // We did not spill into the map based storage.
247     if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;
248 
249     // Adds a happens before edge from the last call to SpilledLocal().
250     std::unique_lock<std::mutex> lock(mu_);
251     for (auto& kv : per_thread_map_) {
252       release_(kv.second);
253     }
254   }
255 
256  private:
257   struct ThreadIdAndValue {
258     std::thread::id thread_id;
259     T value;
260   };
261 
262   // Use unordered map guarded by a mutex when lock free storage is full.
SpilledLocal(std::thread::id this_thread)263   T& SpilledLocal(std::thread::id this_thread) {
264     std::unique_lock<std::mutex> lock(mu_);
265 
266     auto it = per_thread_map_.find(this_thread);
267     if (it == per_thread_map_.end()) {
268       auto result = per_thread_map_.emplace(this_thread, T());
269       eigen_assert(result.second);
270       initialize_((*result.first).second);
271       return (*result.first).second;
272     } else {
273       return it->second;
274     }
275   }
276 
277   Initialize initialize_;
278   Release release_;
279   const int capacity_;
280 
281   // Storage that backs lock-free lookup table `ptr_`. Records stored in this
282   // storage contiguously starting from index 0.
283   MaxSizeVector<ThreadIdAndValue> data_;
284 
285   // Atomic pointers to the data stored in `data_`. Used as a lookup table for
286   // linear probing hash map (https://en.wikipedia.org/wiki/Linear_probing).
287   MaxSizeVector<std::atomic<ThreadIdAndValue*>> ptr_;
288 
289   // Number of records stored in the `data_`.
290   std::atomic<int> filled_records_;
291 
292   // We fallback on per thread map if lock-free storage is full. In practice
293   // this should never happen, if `capacity_` is a reasonable estimate of the
294   // number of threads running in a system.
295   std::mutex mu_;  // Protects per_thread_map_.
296   std::unordered_map<std::thread::id, T> per_thread_map_;
297 };
298 
299 }  // namespace Eigen
300 
301 #endif  // EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H
302