1 #include <c10/macros/Macros.h> 2 #include <c10/util/Synchronized.h> 3 #include <array> 4 #include <atomic> 5 #include <mutex> 6 #include <thread> 7 8 namespace c10 { 9 10 namespace detail { 11 12 struct IncrementRAII final { 13 public: IncrementRAIIfinal14 explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) { 15 _counter->fetch_add(1); 16 } 17 ~IncrementRAIIfinal18 ~IncrementRAII() { 19 _counter->fetch_sub(1); 20 } 21 22 private: 23 std::atomic<int32_t>* _counter; 24 25 C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); 26 }; 27 28 } // namespace detail 29 30 // LeftRight wait-free readers synchronization primitive 31 // https://hal.archives-ouvertes.fr/hal-01207881/document 32 // 33 // LeftRight is quite easy to use (it can make an arbitrary 34 // data structure permit wait-free reads), but it has some 35 // particular performance characteristics you should be aware 36 // of if you're deciding to use it: 37 // 38 // - Reads still incur an atomic write (this is how LeftRight 39 // keeps track of how long it needs to keep around the old 40 // data structure) 41 // 42 // - Writes get executed twice, to keep both the left and right 43 // versions up to date. So if your write is expensive or 44 // nondeterministic, this is also an inappropriate structure 45 // 46 // LeftRight is used fairly rarely in PyTorch's codebase. If you 47 // are still not sure if you need it or not, consult your local 48 // C++ expert. 49 // 50 template <class T> 51 class LeftRight final { 52 public: 53 template <class... Args> LeftRight(const Args &...args)54 explicit LeftRight(const Args&... args) 55 : _counters{{{0}, {0}}}, 56 _foregroundCounterIndex(0), 57 _foregroundDataIndex(0), 58 _data{{T{args...}, T{args...}}}, 59 _writeMutex() {} 60 61 // Copying and moving would not be threadsafe. 62 // Needs more thought and careful design to make that work. 63 LeftRight(const LeftRight&) = delete; 64 LeftRight(LeftRight&&) noexcept = delete; 65 LeftRight& operator=(const LeftRight&) = delete; 66 LeftRight& operator=(LeftRight&&) noexcept = delete; 67 ~LeftRight()68 ~LeftRight() { 69 // wait until any potentially running writers are finished 70 { std::unique_lock<std::mutex> lock(_writeMutex); } 71 72 // wait until any potentially running readers are finished 73 while (_counters[0].load() != 0 || _counters[1].load() != 0) { 74 std::this_thread::yield(); 75 } 76 } 77 78 template <typename F> read(F && readFunc)79 auto read(F&& readFunc) const { 80 detail::IncrementRAII _increment_counter( 81 &_counters[_foregroundCounterIndex.load()]); 82 83 return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]); 84 } 85 86 // Throwing an exception in writeFunc is ok but causes the state to be either 87 // the old or the new state, depending on if the first or the second call to 88 // writeFunc threw. 89 template <typename F> write(F && writeFunc)90 auto write(F&& writeFunc) { 91 std::unique_lock<std::mutex> lock(_writeMutex); 92 93 return _write(std::forward<F>(writeFunc)); 94 } 95 96 private: 97 template <class F> _write(const F & writeFunc)98 auto _write(const F& writeFunc) { 99 /* 100 * Assume, A is in background and B in foreground. In simplified terms, we 101 * want to do the following: 102 * 1. Write to A (old background) 103 * 2. Switch A/B 104 * 3. Write to B (new background) 105 * 106 * More detailed algorithm (explanations on why this is important are below 107 * in code): 108 * 1. Write to A 109 * 2. Switch A/B data pointers 110 * 3. Wait until A counter is zero 111 * 4. Switch A/B counters 112 * 5. Wait until B counter is zero 113 * 6. Write to B 114 */ 115 116 auto localDataIndex = _foregroundDataIndex.load(); 117 118 // 1. Write to A 119 _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); 120 121 // 2. Switch A/B data pointers 122 localDataIndex = localDataIndex ^ 1; 123 _foregroundDataIndex = localDataIndex; 124 125 /* 126 * 3. Wait until A counter is zero 127 * 128 * In the previous write run, A was foreground and B was background. 129 * There was a time after switching _foregroundDataIndex (B to foreground) 130 * and before switching _foregroundCounterIndex, in which new readers could 131 * have read B but incremented A's counter. 132 * 133 * In this current run, we just switched _foregroundDataIndex (A back to 134 * foreground), but before writing to the new background B, we have to make 135 * sure A's counter was zero briefly, so all these old readers are gone. 136 */ 137 auto localCounterIndex = _foregroundCounterIndex.load(); 138 _waitForBackgroundCounterToBeZero(localCounterIndex); 139 140 /* 141 * 4. Switch A/B counters 142 * 143 * Now that we know all readers on B are really gone, we can switch the 144 * counters and have new readers increment A's counter again, which is the 145 * correct counter since they're reading A. 146 */ 147 localCounterIndex = localCounterIndex ^ 1; 148 _foregroundCounterIndex = localCounterIndex; 149 150 /* 151 * 5. Wait until B counter is zero 152 * 153 * This waits for all the readers on B that came in while both data and 154 * counter for B was in foreground, i.e. normal readers that happened 155 * outside of that brief gap between switching data and counter. 156 */ 157 _waitForBackgroundCounterToBeZero(localCounterIndex); 158 159 // 6. Write to B 160 return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); 161 } 162 163 template <class F> _callWriteFuncOnBackgroundInstance(const F & writeFunc,uint8_t localDataIndex)164 auto _callWriteFuncOnBackgroundInstance( 165 const F& writeFunc, 166 uint8_t localDataIndex) { 167 try { 168 return writeFunc(_data[localDataIndex ^ 1]); 169 } catch (...) { 170 // recover invariant by copying from the foreground instance 171 _data[localDataIndex ^ 1] = _data[localDataIndex]; 172 // rethrow 173 throw; 174 } 175 } 176 _waitForBackgroundCounterToBeZero(uint8_t counterIndex)177 void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { 178 while (_counters[counterIndex ^ 1].load() != 0) { 179 std::this_thread::yield(); 180 } 181 } 182 183 mutable std::array<std::atomic<int32_t>, 2> _counters; 184 std::atomic<uint8_t> _foregroundCounterIndex; 185 std::atomic<uint8_t> _foregroundDataIndex; 186 std::array<T, 2> _data; 187 std::mutex _writeMutex; 188 }; 189 190 // RWSafeLeftRightWrapper is API compatible with LeftRight and uses a 191 // read-write lock to protect T (data). 192 template <class T> 193 class RWSafeLeftRightWrapper final { 194 public: 195 template <class... Args> RWSafeLeftRightWrapper(const Args &...args)196 explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} 197 198 // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight 199 // is not copyable or moveable. 200 RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; 201 RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; 202 RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; 203 RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; 204 205 template <typename F> 206 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) read(F && readFunc)207 auto read(F&& readFunc) const { 208 return data_.withLock( 209 [&readFunc](T const& data) { return std::forward<F>(readFunc)(data); }); 210 } 211 212 template <typename F> 213 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) write(F && writeFunc)214 auto write(F&& writeFunc) { 215 return data_.withLock( 216 [&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); }); 217 } 218 219 private: 220 c10::Synchronized<T> data_; 221 }; 222 223 } // namespace c10 224