1*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h> 2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Synchronized.h> 3*da0073e9SAndroid Build Coastguard Worker #include <array> 4*da0073e9SAndroid Build Coastguard Worker #include <atomic> 5*da0073e9SAndroid Build Coastguard Worker #include <mutex> 6*da0073e9SAndroid Build Coastguard Worker #include <thread> 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker namespace c10 { 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker namespace detail { 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker struct IncrementRAII final { 13*da0073e9SAndroid Build Coastguard Worker public: IncrementRAIIfinal14*da0073e9SAndroid Build Coastguard Worker explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) { 15*da0073e9SAndroid Build Coastguard Worker _counter->fetch_add(1); 16*da0073e9SAndroid Build Coastguard Worker } 17*da0073e9SAndroid Build Coastguard Worker ~IncrementRAIIfinal18*da0073e9SAndroid Build Coastguard Worker ~IncrementRAII() { 19*da0073e9SAndroid Build Coastguard Worker _counter->fetch_sub(1); 20*da0073e9SAndroid Build Coastguard Worker } 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker private: 23*da0073e9SAndroid Build Coastguard Worker std::atomic<int32_t>* _counter; 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); 26*da0073e9SAndroid Build Coastguard Worker }; 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker } // namespace detail 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker // LeftRight wait-free readers synchronization primitive 31*da0073e9SAndroid Build Coastguard Worker // https://hal.archives-ouvertes.fr/hal-01207881/document 32*da0073e9SAndroid Build Coastguard Worker // 33*da0073e9SAndroid Build Coastguard Worker // LeftRight is quite easy to use (it can make an arbitrary 34*da0073e9SAndroid Build Coastguard Worker // data structure permit wait-free reads), but it has some 35*da0073e9SAndroid Build Coastguard Worker // particular performance characteristics you should be aware 36*da0073e9SAndroid Build Coastguard Worker // of if you're deciding to use it: 37*da0073e9SAndroid Build Coastguard Worker // 38*da0073e9SAndroid Build Coastguard Worker // - Reads still incur an atomic write (this is how LeftRight 39*da0073e9SAndroid Build Coastguard Worker // keeps track of how long it needs to keep around the old 40*da0073e9SAndroid Build Coastguard Worker // data structure) 41*da0073e9SAndroid Build Coastguard Worker // 42*da0073e9SAndroid Build Coastguard Worker // - Writes get executed twice, to keep both the left and right 43*da0073e9SAndroid Build Coastguard Worker // versions up to date. So if your write is expensive or 44*da0073e9SAndroid Build Coastguard Worker // nondeterministic, this is also an inappropriate structure 45*da0073e9SAndroid Build Coastguard Worker // 46*da0073e9SAndroid Build Coastguard Worker // LeftRight is used fairly rarely in PyTorch's codebase. If you 47*da0073e9SAndroid Build Coastguard Worker // are still not sure if you need it or not, consult your local 48*da0073e9SAndroid Build Coastguard Worker // C++ expert. 49*da0073e9SAndroid Build Coastguard Worker // 50*da0073e9SAndroid Build Coastguard Worker template <class T> 51*da0073e9SAndroid Build Coastguard Worker class LeftRight final { 52*da0073e9SAndroid Build Coastguard Worker public: 53*da0073e9SAndroid Build Coastguard Worker template <class... Args> LeftRight(const Args &...args)54*da0073e9SAndroid Build Coastguard Worker explicit LeftRight(const Args&... args) 55*da0073e9SAndroid Build Coastguard Worker : _counters{{{0}, {0}}}, 56*da0073e9SAndroid Build Coastguard Worker _foregroundCounterIndex(0), 57*da0073e9SAndroid Build Coastguard Worker _foregroundDataIndex(0), 58*da0073e9SAndroid Build Coastguard Worker _data{{T{args...}, T{args...}}}, 59*da0073e9SAndroid Build Coastguard Worker _writeMutex() {} 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker // Copying and moving would not be threadsafe. 62*da0073e9SAndroid Build Coastguard Worker // Needs more thought and careful design to make that work. 63*da0073e9SAndroid Build Coastguard Worker LeftRight(const LeftRight&) = delete; 64*da0073e9SAndroid Build Coastguard Worker LeftRight(LeftRight&&) noexcept = delete; 65*da0073e9SAndroid Build Coastguard Worker LeftRight& operator=(const LeftRight&) = delete; 66*da0073e9SAndroid Build Coastguard Worker LeftRight& operator=(LeftRight&&) noexcept = delete; 67*da0073e9SAndroid Build Coastguard Worker ~LeftRight()68*da0073e9SAndroid Build Coastguard Worker ~LeftRight() { 69*da0073e9SAndroid Build Coastguard Worker // wait until any potentially running writers are finished 70*da0073e9SAndroid Build Coastguard Worker { std::unique_lock<std::mutex> lock(_writeMutex); } 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker // wait until any potentially running readers are finished 73*da0073e9SAndroid Build Coastguard Worker while (_counters[0].load() != 0 || _counters[1].load() != 0) { 74*da0073e9SAndroid Build Coastguard Worker std::this_thread::yield(); 75*da0073e9SAndroid Build Coastguard Worker } 76*da0073e9SAndroid Build Coastguard Worker } 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker template <typename F> read(F && readFunc)79*da0073e9SAndroid Build Coastguard Worker auto read(F&& readFunc) const { 80*da0073e9SAndroid Build Coastguard Worker detail::IncrementRAII _increment_counter( 81*da0073e9SAndroid Build Coastguard Worker &_counters[_foregroundCounterIndex.load()]); 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]); 84*da0073e9SAndroid Build Coastguard Worker } 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker // Throwing an exception in writeFunc is ok but causes the state to be either 87*da0073e9SAndroid Build Coastguard Worker // the old or the new state, depending on if the first or the second call to 88*da0073e9SAndroid Build Coastguard Worker // writeFunc threw. 89*da0073e9SAndroid Build Coastguard Worker template <typename F> write(F && writeFunc)90*da0073e9SAndroid Build Coastguard Worker auto write(F&& writeFunc) { 91*da0073e9SAndroid Build Coastguard Worker std::unique_lock<std::mutex> lock(_writeMutex); 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker return _write(std::forward<F>(writeFunc)); 94*da0073e9SAndroid Build Coastguard Worker } 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker private: 97*da0073e9SAndroid Build Coastguard Worker template <class F> _write(const F & writeFunc)98*da0073e9SAndroid Build Coastguard Worker auto _write(const F& writeFunc) { 99*da0073e9SAndroid Build Coastguard Worker /* 100*da0073e9SAndroid Build Coastguard Worker * Assume, A is in background and B in foreground. In simplified terms, we 101*da0073e9SAndroid Build Coastguard Worker * want to do the following: 102*da0073e9SAndroid Build Coastguard Worker * 1. Write to A (old background) 103*da0073e9SAndroid Build Coastguard Worker * 2. Switch A/B 104*da0073e9SAndroid Build Coastguard Worker * 3. Write to B (new background) 105*da0073e9SAndroid Build Coastguard Worker * 106*da0073e9SAndroid Build Coastguard Worker * More detailed algorithm (explanations on why this is important are below 107*da0073e9SAndroid Build Coastguard Worker * in code): 108*da0073e9SAndroid Build Coastguard Worker * 1. Write to A 109*da0073e9SAndroid Build Coastguard Worker * 2. Switch A/B data pointers 110*da0073e9SAndroid Build Coastguard Worker * 3. Wait until A counter is zero 111*da0073e9SAndroid Build Coastguard Worker * 4. Switch A/B counters 112*da0073e9SAndroid Build Coastguard Worker * 5. Wait until B counter is zero 113*da0073e9SAndroid Build Coastguard Worker * 6. Write to B 114*da0073e9SAndroid Build Coastguard Worker */ 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker auto localDataIndex = _foregroundDataIndex.load(); 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker // 1. Write to A 119*da0073e9SAndroid Build Coastguard Worker _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker // 2. Switch A/B data pointers 122*da0073e9SAndroid Build Coastguard Worker localDataIndex = localDataIndex ^ 1; 123*da0073e9SAndroid Build Coastguard Worker _foregroundDataIndex = localDataIndex; 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker /* 126*da0073e9SAndroid Build Coastguard Worker * 3. Wait until A counter is zero 127*da0073e9SAndroid Build Coastguard Worker * 128*da0073e9SAndroid Build Coastguard Worker * In the previous write run, A was foreground and B was background. 129*da0073e9SAndroid Build Coastguard Worker * There was a time after switching _foregroundDataIndex (B to foreground) 130*da0073e9SAndroid Build Coastguard Worker * and before switching _foregroundCounterIndex, in which new readers could 131*da0073e9SAndroid Build Coastguard Worker * have read B but incremented A's counter. 132*da0073e9SAndroid Build Coastguard Worker * 133*da0073e9SAndroid Build Coastguard Worker * In this current run, we just switched _foregroundDataIndex (A back to 134*da0073e9SAndroid Build Coastguard Worker * foreground), but before writing to the new background B, we have to make 135*da0073e9SAndroid Build Coastguard Worker * sure A's counter was zero briefly, so all these old readers are gone. 136*da0073e9SAndroid Build Coastguard Worker */ 137*da0073e9SAndroid Build Coastguard Worker auto localCounterIndex = _foregroundCounterIndex.load(); 138*da0073e9SAndroid Build Coastguard Worker _waitForBackgroundCounterToBeZero(localCounterIndex); 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker /* 141*da0073e9SAndroid Build Coastguard Worker * 4. Switch A/B counters 142*da0073e9SAndroid Build Coastguard Worker * 143*da0073e9SAndroid Build Coastguard Worker * Now that we know all readers on B are really gone, we can switch the 144*da0073e9SAndroid Build Coastguard Worker * counters and have new readers increment A's counter again, which is the 145*da0073e9SAndroid Build Coastguard Worker * correct counter since they're reading A. 146*da0073e9SAndroid Build Coastguard Worker */ 147*da0073e9SAndroid Build Coastguard Worker localCounterIndex = localCounterIndex ^ 1; 148*da0073e9SAndroid Build Coastguard Worker _foregroundCounterIndex = localCounterIndex; 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker /* 151*da0073e9SAndroid Build Coastguard Worker * 5. Wait until B counter is zero 152*da0073e9SAndroid Build Coastguard Worker * 153*da0073e9SAndroid Build Coastguard Worker * This waits for all the readers on B that came in while both data and 154*da0073e9SAndroid Build Coastguard Worker * counter for B was in foreground, i.e. normal readers that happened 155*da0073e9SAndroid Build Coastguard Worker * outside of that brief gap between switching data and counter. 156*da0073e9SAndroid Build Coastguard Worker */ 157*da0073e9SAndroid Build Coastguard Worker _waitForBackgroundCounterToBeZero(localCounterIndex); 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker // 6. Write to B 160*da0073e9SAndroid Build Coastguard Worker return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); 161*da0073e9SAndroid Build Coastguard Worker } 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker template <class F> _callWriteFuncOnBackgroundInstance(const F & writeFunc,uint8_t localDataIndex)164*da0073e9SAndroid Build Coastguard Worker auto _callWriteFuncOnBackgroundInstance( 165*da0073e9SAndroid Build Coastguard Worker const F& writeFunc, 166*da0073e9SAndroid Build Coastguard Worker uint8_t localDataIndex) { 167*da0073e9SAndroid Build Coastguard Worker try { 168*da0073e9SAndroid Build Coastguard Worker return writeFunc(_data[localDataIndex ^ 1]); 169*da0073e9SAndroid Build Coastguard Worker } catch (...) { 170*da0073e9SAndroid Build Coastguard Worker // recover invariant by copying from the foreground instance 171*da0073e9SAndroid Build Coastguard Worker _data[localDataIndex ^ 1] = _data[localDataIndex]; 172*da0073e9SAndroid Build Coastguard Worker // rethrow 173*da0073e9SAndroid Build Coastguard Worker throw; 174*da0073e9SAndroid Build Coastguard Worker } 175*da0073e9SAndroid Build Coastguard Worker } 176*da0073e9SAndroid Build Coastguard Worker _waitForBackgroundCounterToBeZero(uint8_t counterIndex)177*da0073e9SAndroid Build Coastguard Worker void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { 178*da0073e9SAndroid Build Coastguard Worker while (_counters[counterIndex ^ 1].load() != 0) { 179*da0073e9SAndroid Build Coastguard Worker std::this_thread::yield(); 180*da0073e9SAndroid Build Coastguard Worker } 181*da0073e9SAndroid Build Coastguard Worker } 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker mutable std::array<std::atomic<int32_t>, 2> _counters; 184*da0073e9SAndroid Build Coastguard Worker std::atomic<uint8_t> _foregroundCounterIndex; 185*da0073e9SAndroid Build Coastguard Worker std::atomic<uint8_t> _foregroundDataIndex; 186*da0073e9SAndroid Build Coastguard Worker std::array<T, 2> _data; 187*da0073e9SAndroid Build Coastguard Worker std::mutex _writeMutex; 188*da0073e9SAndroid Build Coastguard Worker }; 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker // RWSafeLeftRightWrapper is API compatible with LeftRight and uses a 191*da0073e9SAndroid Build Coastguard Worker // read-write lock to protect T (data). 192*da0073e9SAndroid Build Coastguard Worker template <class T> 193*da0073e9SAndroid Build Coastguard Worker class RWSafeLeftRightWrapper final { 194*da0073e9SAndroid Build Coastguard Worker public: 195*da0073e9SAndroid Build Coastguard Worker template <class... Args> RWSafeLeftRightWrapper(const Args &...args)196*da0073e9SAndroid Build Coastguard Worker explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight 199*da0073e9SAndroid Build Coastguard Worker // is not copyable or moveable. 200*da0073e9SAndroid Build Coastguard Worker RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; 201*da0073e9SAndroid Build Coastguard Worker RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; 202*da0073e9SAndroid Build Coastguard Worker RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; 203*da0073e9SAndroid Build Coastguard Worker RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker template <typename F> 206*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) read(F && readFunc)207*da0073e9SAndroid Build Coastguard Worker auto read(F&& readFunc) const { 208*da0073e9SAndroid Build Coastguard Worker return data_.withLock( 209*da0073e9SAndroid Build Coastguard Worker [&readFunc](T const& data) { return std::forward<F>(readFunc)(data); }); 210*da0073e9SAndroid Build Coastguard Worker } 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker template <typename F> 213*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) write(F && writeFunc)214*da0073e9SAndroid Build Coastguard Worker auto write(F&& writeFunc) { 215*da0073e9SAndroid Build Coastguard Worker return data_.withLock( 216*da0073e9SAndroid Build Coastguard Worker [&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); }); 217*da0073e9SAndroid Build Coastguard Worker } 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker private: 220*da0073e9SAndroid Build Coastguard Worker c10::Synchronized<T> data_; 221*da0073e9SAndroid Build Coastguard Worker }; 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker } // namespace c10 224