xref: /aosp_15_r20/external/pytorch/c10/util/LeftRight.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/macros/Macros.h>
#include <c10/util/Synchronized.h>
#include <array>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>
#include <utility>
7 
8 namespace c10 {
9 
10 namespace detail {
11 
// Scope guard that counts one in-flight user on an atomic counter: the
// constructor increments the counter and the destructor decrements it again,
// so the count is released on every exit path (including stack unwinding).
//
// The guard stores the pointer it is given without checking it. The pointer
// must be non-null and the counter must outlive the guard.
struct IncrementRAII final {
 public:
  explicit IncrementRAII(std::atomic<int32_t>* counter) noexcept
      : _counter(counter) {
    _counter->fetch_add(1);
  }

  ~IncrementRAII() {
    _counter->fetch_sub(1);
  }

  // A copied or moved guard would decrement the counter more than once,
  // so all four copy/move operations are disabled.
  IncrementRAII(const IncrementRAII&) = delete;
  IncrementRAII(IncrementRAII&&) noexcept = delete;
  IncrementRAII& operator=(const IncrementRAII&) = delete;
  IncrementRAII& operator=(IncrementRAII&&) noexcept = delete;

 private:
  std::atomic<int32_t>* _counter;
};
27 
28 } // namespace detail
29 
30 // LeftRight wait-free readers synchronization primitive
31 // https://hal.archives-ouvertes.fr/hal-01207881/document
32 //
33 // LeftRight is quite easy to use (it can make an arbitrary
34 // data structure permit wait-free reads), but it has some
35 // particular performance characteristics you should be aware
36 // of if you're deciding to use it:
37 //
38 //  - Reads still incur an atomic write (this is how LeftRight
39 //    keeps track of how long it needs to keep around the old
40 //    data structure)
41 //
42 //  - Writes get executed twice, to keep both the left and right
43 //    versions up to date.  So if your write is expensive or
44 //    nondeterministic, this is also an inappropriate structure
45 //
46 // LeftRight is used fairly rarely in PyTorch's codebase.  If you
47 // are still not sure if you need it or not, consult your local
48 // C++ expert.
49 //
50 template <class T>
51 class LeftRight final {
52  public:
53   template <class... Args>
LeftRight(const Args &...args)54   explicit LeftRight(const Args&... args)
55       : _counters{{{0}, {0}}},
56         _foregroundCounterIndex(0),
57         _foregroundDataIndex(0),
58         _data{{T{args...}, T{args...}}},
59         _writeMutex() {}
60 
61   // Copying and moving would not be threadsafe.
62   // Needs more thought and careful design to make that work.
63   LeftRight(const LeftRight&) = delete;
64   LeftRight(LeftRight&&) noexcept = delete;
65   LeftRight& operator=(const LeftRight&) = delete;
66   LeftRight& operator=(LeftRight&&) noexcept = delete;
67 
~LeftRight()68   ~LeftRight() {
69     // wait until any potentially running writers are finished
70     { std::unique_lock<std::mutex> lock(_writeMutex); }
71 
72     // wait until any potentially running readers are finished
73     while (_counters[0].load() != 0 || _counters[1].load() != 0) {
74       std::this_thread::yield();
75     }
76   }
77 
78   template <typename F>
read(F && readFunc)79   auto read(F&& readFunc) const {
80     detail::IncrementRAII _increment_counter(
81         &_counters[_foregroundCounterIndex.load()]);
82 
83     return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
84   }
85 
86   // Throwing an exception in writeFunc is ok but causes the state to be either
87   // the old or the new state, depending on if the first or the second call to
88   // writeFunc threw.
89   template <typename F>
write(F && writeFunc)90   auto write(F&& writeFunc) {
91     std::unique_lock<std::mutex> lock(_writeMutex);
92 
93     return _write(std::forward<F>(writeFunc));
94   }
95 
96  private:
97   template <class F>
_write(const F & writeFunc)98   auto _write(const F& writeFunc) {
99     /*
100      * Assume, A is in background and B in foreground. In simplified terms, we
101      * want to do the following:
102      * 1. Write to A (old background)
103      * 2. Switch A/B
104      * 3. Write to B (new background)
105      *
106      * More detailed algorithm (explanations on why this is important are below
107      * in code):
108      * 1. Write to A
109      * 2. Switch A/B data pointers
110      * 3. Wait until A counter is zero
111      * 4. Switch A/B counters
112      * 5. Wait until B counter is zero
113      * 6. Write to B
114      */
115 
116     auto localDataIndex = _foregroundDataIndex.load();
117 
118     // 1. Write to A
119     _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
120 
121     // 2. Switch A/B data pointers
122     localDataIndex = localDataIndex ^ 1;
123     _foregroundDataIndex = localDataIndex;
124 
125     /*
126      * 3. Wait until A counter is zero
127      *
128      * In the previous write run, A was foreground and B was background.
129      * There was a time after switching _foregroundDataIndex (B to foreground)
130      * and before switching _foregroundCounterIndex, in which new readers could
131      * have read B but incremented A's counter.
132      *
133      * In this current run, we just switched _foregroundDataIndex (A back to
134      * foreground), but before writing to the new background B, we have to make
135      * sure A's counter was zero briefly, so all these old readers are gone.
136      */
137     auto localCounterIndex = _foregroundCounterIndex.load();
138     _waitForBackgroundCounterToBeZero(localCounterIndex);
139 
140     /*
141      * 4. Switch A/B counters
142      *
143      * Now that we know all readers on B are really gone, we can switch the
144      * counters and have new readers increment A's counter again, which is the
145      * correct counter since they're reading A.
146      */
147     localCounterIndex = localCounterIndex ^ 1;
148     _foregroundCounterIndex = localCounterIndex;
149 
150     /*
151      * 5. Wait until B counter is zero
152      *
153      * This waits for all the readers on B that came in while both data and
154      * counter for B was in foreground, i.e. normal readers that happened
155      * outside of that brief gap between switching data and counter.
156      */
157     _waitForBackgroundCounterToBeZero(localCounterIndex);
158 
159     // 6. Write to B
160     return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
161   }
162 
163   template <class F>
_callWriteFuncOnBackgroundInstance(const F & writeFunc,uint8_t localDataIndex)164   auto _callWriteFuncOnBackgroundInstance(
165       const F& writeFunc,
166       uint8_t localDataIndex) {
167     try {
168       return writeFunc(_data[localDataIndex ^ 1]);
169     } catch (...) {
170       // recover invariant by copying from the foreground instance
171       _data[localDataIndex ^ 1] = _data[localDataIndex];
172       // rethrow
173       throw;
174     }
175   }
176 
_waitForBackgroundCounterToBeZero(uint8_t counterIndex)177   void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
178     while (_counters[counterIndex ^ 1].load() != 0) {
179       std::this_thread::yield();
180     }
181   }
182 
183   mutable std::array<std::atomic<int32_t>, 2> _counters;
184   std::atomic<uint8_t> _foregroundCounterIndex;
185   std::atomic<uint8_t> _foregroundDataIndex;
186   std::array<T, 2> _data;
187   std::mutex _writeMutex;
188 };
189 
190 // RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
191 // read-write lock to protect T (data).
192 template <class T>
193 class RWSafeLeftRightWrapper final {
194  public:
195   template <class... Args>
RWSafeLeftRightWrapper(const Args &...args)196   explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
197 
198   // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
199   // is not copyable or moveable.
200   RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
201   RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
202   RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
203   RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
204 
205   template <typename F>
206   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
read(F && readFunc)207   auto read(F&& readFunc) const {
208     return data_.withLock(
209         [&readFunc](T const& data) { return std::forward<F>(readFunc)(data); });
210   }
211 
212   template <typename F>
213   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
write(F && writeFunc)214   auto write(F&& writeFunc) {
215     return data_.withLock(
216         [&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); });
217   }
218 
219  private:
220   c10::Synchronized<T> data_;
221 };
222 
223 } // namespace c10
224