xref: /aosp_15_r20/external/pytorch/c10/util/ThreadLocal.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker /**
 * Android builds that use libgnustl mishandle the `thread_local` C++
 * qualifier for composite types; NDK versions up to r17 are affected.
8*da0073e9SAndroid Build Coastguard Worker  *
9*da0073e9SAndroid Build Coastguard Worker  * (A fix landed on Jun 4 2018:
10*da0073e9SAndroid Build Coastguard Worker  * https://android-review.googlesource.com/c/toolchain/gcc/+/683601)
11*da0073e9SAndroid Build Coastguard Worker  *
 * In such cases, use the c10::ThreadLocal<T> wrapper, which is built on
 * `pthread_*` thread-specific storage and has smart-pointer semantics.
14*da0073e9SAndroid Build Coastguard Worker  *
 * In addition, the convenience macro C10_DEFINE_TLS_static is available.
 * To define a static TLS variable of type std::string, do the following:
17*da0073e9SAndroid Build Coastguard Worker  * ```
18*da0073e9SAndroid Build Coastguard Worker  *  C10_DEFINE_TLS_static(std::string, str_tls_);
19*da0073e9SAndroid Build Coastguard Worker  *  ///////
20*da0073e9SAndroid Build Coastguard Worker  *  {
21*da0073e9SAndroid Build Coastguard Worker  *    *str_tls_ = "abc";
 *    assert(str_tls_->length() == 3);
23*da0073e9SAndroid Build Coastguard Worker  *  }
24*da0073e9SAndroid Build Coastguard Worker  * ```
25*da0073e9SAndroid Build Coastguard Worker  *
26*da0073e9SAndroid Build Coastguard Worker  * (see c10/test/util/ThreadLocal_test.cpp for more examples)
27*da0073e9SAndroid Build Coastguard Worker  */
// Automatically fall back to the pthread-based implementation on Android
// toolchains whose libgnustl predates the 2018-06-04 thread_local fix.
// If the macro is already defined (e.g. by the build), it is left untouched.
#if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) && \
    defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604
#define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE
#endif
35*da0073e9SAndroid Build Coastguard Worker 
36*da0073e9SAndroid Build Coastguard Worker #if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
37*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
38*da0073e9SAndroid Build Coastguard Worker #include <errno.h>
39*da0073e9SAndroid Build Coastguard Worker #include <pthread.h>
40*da0073e9SAndroid Build Coastguard Worker #include <memory>
41*da0073e9SAndroid Build Coastguard Worker namespace c10 {
42*da0073e9SAndroid Build Coastguard Worker 
43*da0073e9SAndroid Build Coastguard Worker /**
44*da0073e9SAndroid Build Coastguard Worker  * @brief Temporary thread_local C++ qualifier replacement for Android
45*da0073e9SAndroid Build Coastguard Worker  * based on `pthread_*`.
46*da0073e9SAndroid Build Coastguard Worker  * To be used with composite types that provide default ctor.
47*da0073e9SAndroid Build Coastguard Worker  */
48*da0073e9SAndroid Build Coastguard Worker template <typename Type>
49*da0073e9SAndroid Build Coastguard Worker class ThreadLocal {
50*da0073e9SAndroid Build Coastguard Worker  public:
ThreadLocal()51*da0073e9SAndroid Build Coastguard Worker   ThreadLocal() {
52*da0073e9SAndroid Build Coastguard Worker     pthread_key_create(
53*da0073e9SAndroid Build Coastguard Worker         &key_, [](void* buf) { delete static_cast<Type*>(buf); });
54*da0073e9SAndroid Build Coastguard Worker   }
55*da0073e9SAndroid Build Coastguard Worker 
~ThreadLocal()56*da0073e9SAndroid Build Coastguard Worker   ~ThreadLocal() {
57*da0073e9SAndroid Build Coastguard Worker     if (void* current = pthread_getspecific(key_)) {
58*da0073e9SAndroid Build Coastguard Worker       delete static_cast<Type*>(current);
59*da0073e9SAndroid Build Coastguard Worker     }
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker     pthread_key_delete(key_);
62*da0073e9SAndroid Build Coastguard Worker   }
63*da0073e9SAndroid Build Coastguard Worker 
64*da0073e9SAndroid Build Coastguard Worker   ThreadLocal(const ThreadLocal&) = delete;
65*da0073e9SAndroid Build Coastguard Worker   ThreadLocal& operator=(const ThreadLocal&) = delete;
66*da0073e9SAndroid Build Coastguard Worker 
get()67*da0073e9SAndroid Build Coastguard Worker   Type& get() {
68*da0073e9SAndroid Build Coastguard Worker     if (void* current = pthread_getspecific(key_)) {
69*da0073e9SAndroid Build Coastguard Worker       return *static_cast<Type*>(current);
70*da0073e9SAndroid Build Coastguard Worker     }
71*da0073e9SAndroid Build Coastguard Worker 
72*da0073e9SAndroid Build Coastguard Worker     std::unique_ptr<Type> ptr = std::make_unique<Type>();
73*da0073e9SAndroid Build Coastguard Worker     if (0 == pthread_setspecific(key_, ptr.get())) {
74*da0073e9SAndroid Build Coastguard Worker       return *ptr.release();
75*da0073e9SAndroid Build Coastguard Worker     }
76*da0073e9SAndroid Build Coastguard Worker 
77*da0073e9SAndroid Build Coastguard Worker     int err = errno;
78*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err);
79*da0073e9SAndroid Build Coastguard Worker   }
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker   Type& operator*() {
82*da0073e9SAndroid Build Coastguard Worker     return get();
83*da0073e9SAndroid Build Coastguard Worker   }
84*da0073e9SAndroid Build Coastguard Worker 
85*da0073e9SAndroid Build Coastguard Worker   Type* operator->() {
86*da0073e9SAndroid Build Coastguard Worker     return &get();
87*da0073e9SAndroid Build Coastguard Worker   }
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker  private:
90*da0073e9SAndroid Build Coastguard Worker   pthread_key_t key_;
91*da0073e9SAndroid Build Coastguard Worker };
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker } // namespace c10
94*da0073e9SAndroid Build Coastguard Worker 
// pthread backend: defines a static ThreadLocal named `VarName` holding a
// `ValueType`. It is default-constructed, which creates its pthread key.
#define C10_DEFINE_TLS_static(ValueType, VarName) \
  static ::c10::ThreadLocal<ValueType> VarName

// pthread backend: declares a static ThreadLocal data member `VarName`
// inside a class body. `OwnerClass` is unused in this expansion.
#define C10_DECLARE_TLS_class_static(OwnerClass, ValueType, VarName) \
  static ::c10::ThreadLocal<ValueType> VarName

// pthread backend: out-of-class definition of a data member declared with
// C10_DECLARE_TLS_class_static.
#define C10_DEFINE_TLS_class_static(OwnerClass, ValueType, VarName) \
  ::c10::ThreadLocal<ValueType> OwnerClass::VarName
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker #else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
104*da0073e9SAndroid Build Coastguard Worker 
namespace c10 {

/**
 * @brief ThreadLocal backed by a native C++ `thread_local` variable.
 *
 * This implementation is used whenever C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE
 * is not defined. The wrapper itself stores only a function pointer. That
 * function is expected to return the address of the calling thread's
 * instance (C10_DEFINE_TLS_static supplies one that wraps a function-local
 * `static thread_local` variable).
 *
 * It exposes the same smart-pointer-style interface (get(), `*`, `->`) as the
 * pthread-based implementation.
 */
template <typename Type>
class ThreadLocal {
 public:
  /// Function returning a pointer to the calling thread's instance.
  using Accessor = Type* (*)();

  explicit ThreadLocal(Accessor accessor) : fetch_(accessor) {}

  ThreadLocal(const ThreadLocal&) = delete;
  ThreadLocal& operator=(const ThreadLocal&) = delete;

  /// Returns the calling thread's instance as reported by the accessor.
  auto get() -> Type& {
    Type* const instance = fetch_();
    return *instance;
  }

  auto operator*() -> Type& {
    return get();
  }

  auto operator->() -> Type* {
    return &get();
  }

 private:
  Accessor fetch_;
};

} // namespace c10
137*da0073e9SAndroid Build Coastguard Worker 
// thread_local backend: defines a static ThreadLocal named `VarName`. Its
// accessor returns the address of a function-local `static thread_local
// ValueType`, so each thread observes its own instance.
#define C10_DEFINE_TLS_static(ValueType, VarName)   \
  static ::c10::ThreadLocal<ValueType> VarName([] { \
    static thread_local ValueType tls_instance;     \
    return &tls_instance;                           \
  })

// thread_local backend: declares a static ThreadLocal data member `VarName`
// inside a class body. `OwnerClass` is unused in this expansion.
#define C10_DECLARE_TLS_class_static(OwnerClass, ValueType, VarName) \
  static ::c10::ThreadLocal<ValueType> VarName

// thread_local backend: out-of-class definition of a data member declared
// with C10_DECLARE_TLS_class_static, using the same per-thread accessor.
#define C10_DEFINE_TLS_class_static(OwnerClass, ValueType, VarName) \
  ::c10::ThreadLocal<ValueType> OwnerClass::VarName([] {            \
    static thread_local ValueType tls_instance;                     \
    return &tls_instance;                                           \
  })
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker #endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
154