xref: /aosp_15_r20/external/pytorch/c10/util/ThreadLocal.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 
5 /**
 * Android versions with libgnustl incorrectly handle the thread_local C++
 * qualifier with composite types. NDK versions up to r17 are affected.
8  *
9  * (A fix landed on Jun 4 2018:
10  * https://android-review.googlesource.com/c/toolchain/gcc/+/683601)
11  *
12  * In such cases, use c10::ThreadLocal<T> wrapper
13  * which is `pthread_*` based with smart pointer semantics.
14  *
 * In addition, a convenience macro C10_DEFINE_TLS_static is available.
 * To define a static TLS variable of type std::string, do the following:
17  * ```
18  *  C10_DEFINE_TLS_static(std::string, str_tls_);
19  *  ///////
20  *  {
21  *    *str_tls_ = "abc";
 *    assert(str_tls_->length() == 3);
23  *  }
24  * ```
25  *
26  * (see c10/test/util/ThreadLocal_test.cpp for more examples)
27  */
28 #if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
29 
30 #if defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604
31 #define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE
32 #endif // defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604
33 
34 #endif // !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
35 
36 #if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
37 #include <c10/util/Exception.h>
38 #include <errno.h>
39 #include <pthread.h>
40 #include <memory>
41 namespace c10 {
42 
43 /**
44  * @brief Temporary thread_local C++ qualifier replacement for Android
45  * based on `pthread_*`.
46  * To be used with composite types that provide default ctor.
47  */
48 template <typename Type>
49 class ThreadLocal {
50  public:
ThreadLocal()51   ThreadLocal() {
52     pthread_key_create(
53         &key_, [](void* buf) { delete static_cast<Type*>(buf); });
54   }
55 
~ThreadLocal()56   ~ThreadLocal() {
57     if (void* current = pthread_getspecific(key_)) {
58       delete static_cast<Type*>(current);
59     }
60 
61     pthread_key_delete(key_);
62   }
63 
64   ThreadLocal(const ThreadLocal&) = delete;
65   ThreadLocal& operator=(const ThreadLocal&) = delete;
66 
get()67   Type& get() {
68     if (void* current = pthread_getspecific(key_)) {
69       return *static_cast<Type*>(current);
70     }
71 
72     std::unique_ptr<Type> ptr = std::make_unique<Type>();
73     if (0 == pthread_setspecific(key_, ptr.get())) {
74       return *ptr.release();
75     }
76 
77     int err = errno;
78     TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err);
79   }
80 
81   Type& operator*() {
82     return get();
83   }
84 
85   Type* operator->() {
86     return &get();
87   }
88 
89  private:
90   pthread_key_t key_;
91 };
92 
93 } // namespace c10
94 
/**
 * TLS helper macros for the custom (pthread-based) storage configuration.
 * Each expands to a declaration or definition of a ::c10::ThreadLocal<Type>
 * object called `Name`; the caller supplies the trailing semicolon.
 */

/* Static TLS variable `Name` holding a `Type`. */
#define C10_DEFINE_TLS_static(Type, Name) \
  static ::c10::ThreadLocal<Type> Name

/* In-class declaration of a static TLS member; `Class` is not used here. */
#define C10_DECLARE_TLS_class_static(Class, Type, Name) \
  static ::c10::ThreadLocal<Type> Name

/* Out-of-class definition of the static TLS member `Class::Name`. */
#define C10_DEFINE_TLS_class_static(Class, Type, Name) \
  ::c10::ThreadLocal<Type> Class::Name
102 
103 #else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
104 
105 namespace c10 {
106 
/**
 * @brief ThreadLocal wrapper used when C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE
 * is not defined (the default, non-Android case).
 *
 * The wrapper owns no storage of its own: it keeps a plain function pointer
 * and forwards every access to it. The object returned by that function is
 * what callers read and write.
 */
template <typename Type>
class ThreadLocal {
 public:
  /// Function that yields a pointer to the instance to expose.
  using Accessor = Type* (*)();

  /// Stores `accessor`; it is invoked on every access.
  explicit ThreadLocal(Accessor accessor) : accessor_{accessor} {}

  // Non-copyable.
  ThreadLocal(const ThreadLocal&) = delete;
  ThreadLocal& operator=(const ThreadLocal&) = delete;

  /// Returns a reference to the object the accessor points at.
  Type& get() {
    Type* const instance = accessor_();
    return *instance;
  }

  /// Same as get().
  Type& operator*() {
    return get();
  }

  /// Pointer access to the object returned by get().
  Type* operator->() {
    return &get();
  }

 private:
  Accessor accessor_;
};
135 
136 } // namespace c10
137 
/**
 * TLS helper macros for the native `thread_local` configuration.
 * The DEFINE variants initialize a ::c10::ThreadLocal<Type> with a
 * capture-less lambda that returns the address of a function-local
 * `static thread_local Type`; the caller supplies the trailing semicolon.
 */

/* Static TLS variable `Name` holding a `Type`. */
#define C10_DEFINE_TLS_static(Type, Name)        \
  static ::c10::ThreadLocal<Type> Name([]() {    \
    static thread_local Type tls_instance;       \
    return &tls_instance;                        \
  })

/* In-class declaration of a static TLS member; `Class` is not used here. */
#define C10_DECLARE_TLS_class_static(Class, Type, Name) \
  static ::c10::ThreadLocal<Type> Name

/* Out-of-class definition of the static TLS member `Class::Name`. */
#define C10_DEFINE_TLS_class_static(Class, Type, Name) \
  ::c10::ThreadLocal<Type> Class::Name([]() {          \
    static thread_local Type tls_instance;             \
    return &tls_instance;                              \
  })
152 
153 #endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
154