// xref: /aosp_15_r20/external/pytorch/c10/util/LeftRight.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/Synchronized.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>
#include <utility>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace c10 {
9*da0073e9SAndroid Build Coastguard Worker 
namespace detail {

// Scope guard that registers one in-flight reader on an atomic counter:
// the constructor adds one to the counter and the destructor subtracts it
// again. The counter is borrowed, not owned, so it has to outlive the guard.
struct IncrementRAII final {
 public:
  explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
    _counter->fetch_add(1);
  }

  ~IncrementRAII() {
    _counter->fetch_sub(1);
  }

  // A copied or moved guard would decrement the counter twice for a single
  // increment, so neither is allowed.
  IncrementRAII(const IncrementRAII&) = delete;
  IncrementRAII(IncrementRAII&&) noexcept = delete;
  IncrementRAII& operator=(const IncrementRAII&) = delete;
  IncrementRAII& operator=(IncrementRAII&&) noexcept = delete;

 private:
  std::atomic<int32_t>* _counter;
};

} // namespace detail

// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
//
// LeftRight is quite easy to use (it can make an arbitrary
// data structure permit wait-free reads), but it has some
// particular performance characteristics you should be aware
// of if you're deciding to use it:
//
//  - Reads still incur an atomic write (this is how LeftRight
//    keeps track of how long it needs to keep around the old
//    data structure)
//
//  - Writes get executed twice, to keep both the left and right
//    versions up to date.  So if your write is expensive or
//    nondeterministic, LeftRight is an inappropriate structure
//
// LeftRight is used fairly rarely in PyTorch's codebase.  If you
// are still not sure if you need it or not, consult your local
// C++ expert.
//
template <class T>
class LeftRight final {
 public:
  // Builds both internal instances of T from the same constructor arguments.
  // The arguments are taken by const reference because they are used twice.
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving would not be threadsafe.
  // Needs more thought and careful design to make that work.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  ~LeftRight() {
    // wait until any potentially running writers are finished
    // (acquire the write mutex once and release it again right away)
    { std::lock_guard<std::mutex> lock(_writeMutex); }

    // wait until any potentially running readers are finished
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Invokes readFunc with a const reference to the current foreground
  // instance and returns whatever readFunc returns. Readers never take the
  // write mutex; they only bump the foreground reader counter for the
  // duration of the call.
  template <typename F>
  auto read(F&& readFunc) const {
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
  }

  // Applies writeFunc to both instances (once each) while holding the write
  // mutex, so writers are serialized against each other. writeFunc is passed
  // on as `const F&`, so it must be callable through a const reference. The
  // value returned is the result of the second invocation; the result of the
  // first one is discarded.
  //
  // Throwing an exception in writeFunc is ok but causes the state to be either
  // the old or the new state, depending on if the first or the second call to
  // writeFunc threw.
  template <typename F>
  auto write(F&& writeFunc) {
    std::lock_guard<std::mutex> lock(_writeMutex);

    return _write(std::forward<F>(writeFunc));
  }

 private:
  // Performs one write cycle. Must only be called with _writeMutex held.
  template <class F>
  auto _write(const F& writeFunc) {
    /*
     * Assume, A is in background and B in foreground. In simplified terms, we
     * want to do the following:
     * 1. Write to A (old background)
     * 2. Switch A/B
     * 3. Write to B (new background)
     *
     * More detailed algorithm (explanations on why this is important are below
     * in code):
     * 1. Write to A
     * 2. Switch A/B data pointers
     * 3. Wait until A counter is zero
     * 4. Switch A/B counters
     * 5. Wait until B counter is zero
     * 6. Write to B
     */

    auto localDataIndex = _foregroundDataIndex.load();

    // 1. Write to A
    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // 2. Switch A/B data pointers
    localDataIndex = localDataIndex ^ 1;
    _foregroundDataIndex = localDataIndex;

    /*
     * 3. Wait until A counter is zero
     *
     * In the previous write run, A was foreground and B was background.
     * There was a time after switching _foregroundDataIndex (B to foreground)
     * and before switching _foregroundCounterIndex, in which new readers could
     * have read B but incremented A's counter.
     *
     * In this current run, we just switched _foregroundDataIndex (A back to
     * foreground), but before writing to the new background B, we have to make
     * sure A's counter was zero briefly, so all these old readers are gone.
     */
    auto localCounterIndex = _foregroundCounterIndex.load();
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    /*
     * 4. Switch A/B counters
     *
     * Now that we know all readers on B are really gone, we can switch the
     * counters and have new readers increment A's counter again, which is the
     * correct counter since they're reading A.
     */
    localCounterIndex = localCounterIndex ^ 1;
    _foregroundCounterIndex = localCounterIndex;

    /*
     * 5. Wait until B counter is zero
     *
     * This waits for all the readers on B that came in while both data and
     * counter for B was in foreground, i.e. normal readers that happened
     * outside of that brief gap between switching data and counter.
     */
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // 6. Write to B
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Runs writeFunc on the instance that is NOT at localDataIndex, i.e. the
  // background instance when localDataIndex is the foreground index. If
  // writeFunc throws, the background instance is overwritten with a copy of
  // the foreground instance before the exception is rethrown, so both
  // instances hold the same state again. This requires T to be copy
  // assignable.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // recover invariant by copying from the foreground instance
      _data[localDataIndex ^ 1] = _data[localDataIndex];
      // rethrow
      throw;
    }
  }

  // Spins (yielding the thread) until the counter that is NOT at counterIndex
  // reads zero.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Reader counters, one per slot; mutable because read() is const but still
  // has to register itself.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  // Index into _counters that new readers increment.
  std::atomic<uint8_t> _foregroundCounterIndex;
  // Index into _data that new readers observe.
  std::atomic<uint8_t> _foregroundDataIndex;
  // The two instances of T; writers modify the one readers are not using.
  std::array<T, 2> _data;
  // Serializes writers (and lets the destructor wait for a running writer).
  std::mutex _writeMutex;
};
189*da0073e9SAndroid Build Coastguard Worker 
// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
// read-write lock to protect T (data).
192*da0073e9SAndroid Build Coastguard Worker template <class T>
193*da0073e9SAndroid Build Coastguard Worker class RWSafeLeftRightWrapper final {
194*da0073e9SAndroid Build Coastguard Worker  public:
195*da0073e9SAndroid Build Coastguard Worker   template <class... Args>
RWSafeLeftRightWrapper(const Args &...args)196*da0073e9SAndroid Build Coastguard Worker   explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
197*da0073e9SAndroid Build Coastguard Worker 
198*da0073e9SAndroid Build Coastguard Worker   // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
199*da0073e9SAndroid Build Coastguard Worker   // is not copyable or moveable.
200*da0073e9SAndroid Build Coastguard Worker   RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
201*da0073e9SAndroid Build Coastguard Worker   RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
202*da0073e9SAndroid Build Coastguard Worker   RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
203*da0073e9SAndroid Build Coastguard Worker   RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
204*da0073e9SAndroid Build Coastguard Worker 
205*da0073e9SAndroid Build Coastguard Worker   template <typename F>
206*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
read(F && readFunc)207*da0073e9SAndroid Build Coastguard Worker   auto read(F&& readFunc) const {
208*da0073e9SAndroid Build Coastguard Worker     return data_.withLock(
209*da0073e9SAndroid Build Coastguard Worker         [&readFunc](T const& data) { return std::forward<F>(readFunc)(data); });
210*da0073e9SAndroid Build Coastguard Worker   }
211*da0073e9SAndroid Build Coastguard Worker 
212*da0073e9SAndroid Build Coastguard Worker   template <typename F>
213*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
write(F && writeFunc)214*da0073e9SAndroid Build Coastguard Worker   auto write(F&& writeFunc) {
215*da0073e9SAndroid Build Coastguard Worker     return data_.withLock(
216*da0073e9SAndroid Build Coastguard Worker         [&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); });
217*da0073e9SAndroid Build Coastguard Worker   }
218*da0073e9SAndroid Build Coastguard Worker 
219*da0073e9SAndroid Build Coastguard Worker  private:
220*da0073e9SAndroid Build Coastguard Worker   c10::Synchronized<T> data_;
221*da0073e9SAndroid Build Coastguard Worker };
222*da0073e9SAndroid Build Coastguard Worker 
223*da0073e9SAndroid Build Coastguard Worker } // namespace c10
224