xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/refcount.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
17 #define TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
18 
#include <atomic>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <utility>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
26 
27 namespace tensorflow {
28 namespace core {
29 
// Base class for intrusively reference-counted objects.
//
// A newly constructed object starts with one reference, owned by its creator.
// Ref() adds a reference; Unref() drops one, and the Unref() call that takes
// the count to zero deletes the object.
class RefCounted {
 public:
  // Starts the reference count at one.
  RefCounted();

  // Adds one reference.
  void Ref() const;

  // Drops one reference. Returns false while other references remain. When
  // this call drops the last reference, the object is deleted and true is
  // returned; the caller must not touch the object afterwards.
  bool Unref() const;

  // Returns the current reference count.
  int_fast32_t RefCount() const;

  // Returns true iff the reference count is exactly one.
  // Under conventional usage, a count of one means the calling thread holds
  // the only reference and nobody else shares it. The check also performs the
  // memory barrier the owning thread needs before acting on the object with
  // the knowledge that it has exclusive access.
  bool RefCountIsOne() const;

  // Copying would duplicate the reference count, so it is forbidden.
  RefCounted(const RefCounted&) = delete;
  void operator=(const RefCounted&) = delete;

 protected:
  // Protected so RefCounted cannot be instantiated directly; only subclasses
  // can be.
  virtual ~RefCounted();

  // Adds one reference unless the object is already being destructed.
  // Returns whether the reference was acquired. Used by WeakRefCounted to
  // safely obtain a strong reference; only safe to call as part of the weak
  // reference implementation.
  bool TryRef() const;

 private:
  mutable std::atomic_int_fast32_t ref_;
};
74 
75 // A deleter class to form a std::unique_ptr that unrefs objects.
76 struct RefCountDeleter {
operatorRefCountDeleter77   void operator()(const RefCounted* o) const { o->Unref(); }
78 };
79 
80 // A unique_ptr that unrefs the owned object on destruction.
81 template <typename T>
82 using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;
83 
84 // Helper class to unref an object when out-of-scope.
85 class ScopedUnref {
86  public:
ScopedUnref(const RefCounted * o)87   explicit ScopedUnref(const RefCounted* o) : obj_(o) {}
~ScopedUnref()88   ~ScopedUnref() {
89     if (obj_) obj_->Unref();
90   }
91 
92  private:
93   const RefCounted* obj_;
94 
95   ScopedUnref(const ScopedUnref&) = delete;
96   void operator=(const ScopedUnref&) = delete;
97 };
98 
// Declared ahead of WeakRefCounted so that class can name WeakPtr as a friend.
template <class T>
class WeakPtr;
102 
// Callback invoked when a weakly referenced object is being destroyed.
// The object may already be destructed by the time the callback runs.
// A WeakNotifyFn can be supplied to WeakPtr's constructor.
using WeakNotifyFn = std::function<void()>;
107 
108 // A base class for RefCounted objects that allow weak references by WeakPtr.
109 // WeakRefCounted and every WeakPtr to it, each holds a strong reference to a
110 // WeakRefData.
111 //
112 // If the WeakRefCounted is valid, WeakPtr::GetNewRef() returns a new strong
113 // reference to the WeakRefCounted.
114 // If the WeakRefCounted is being destructed, `WeakRefCounted::ref_ == 0`;
115 // if the WeakRefcounted is already destructed,`WeakRefData::ptr == nullptr`.
116 // In either case, WeakPtr::GetNewRef() returns a nullptr.
117 class WeakRefCounted : public RefCounted {
118  public:
WeakRefCount()119   int WeakRefCount() const {
120     // Each weak ref owns one ref to data_, and *this owns the last one.
121     return data_->RefCount() - 1;
122   }
123 
124  protected:
~WeakRefCounted()125   ~WeakRefCounted() override { data_->Notify(); }
126 
127  private:
128   struct WeakRefData : public RefCounted {
WeakRefDataWeakRefData129     explicit WeakRefData(WeakRefCounted* ptr) : ptr(ptr), next_notifier_id(1) {}
130 
131     mutable mutex mu;
132     WeakRefCounted* ptr TF_GUARDED_BY(mu);
133     std::map<int, WeakNotifyFn> notifiers;
134     int next_notifier_id;
135 
136     // Notifies WeakPtr instansces that this object is being destructed.
NotifyWeakRefData137     void Notify() {
138       mutex_lock ml(mu);
139 
140       while (!notifiers.empty()) {
141         auto iter = notifiers.begin();
142         WeakNotifyFn notify_fn = std::move(iter->second);
143         notifiers.erase(iter);
144 
145         mu.unlock();
146         notify_fn();
147         mu.lock();
148       }
149       ptr = nullptr;
150     }
151 
GetNewRefWeakRefData152     WeakRefCounted* GetNewRef() {
153       mutex_lock ml(mu);
154       if (ptr != nullptr && ptr->TryRef()) {
155         return ptr;
156       }
157       return nullptr;
158     }
159 
160     // Inserts notify_fn and returns a non-zero id.
161     // Returns 0 if insertion fails due to the object is being destroyed.
162     // 0 is also used by WeakPtr to represent "no notify_fn".
AddNotifierWeakRefData163     int AddNotifier(WeakNotifyFn notify_fn) {
164       mutex_lock ml(mu);
165       if (ptr == nullptr) {
166         return 0;
167       }
168       int notifier_id = next_notifier_id++;
169       notifiers.emplace(notifier_id, std::move(notify_fn));
170       return notifier_id;
171     }
172 
RemoveNotifierWeakRefData173     void RemoveNotifier(int notifier_id) {
174       mutex_lock ml(mu);
175       notifiers.erase(notifier_id);
176     }
177   };
178 
179   RefCountPtr<WeakRefData> data_{new WeakRefData(this)};
180 
181   template <typename T>
182   friend class WeakPtr;
183   // MSVC14 workaround: access permission of a nested class member is not
184   // treated as an ordinary member in MSVC14.
185   friend struct WeakRefData;
186 };
187 
188 // A weak reference to a WeakRefCounted object. Refer to WeakRefCounted.
189 template <typename T>
190 class WeakPtr {
191  public:
192   // Creates a weak reference.
193   // When the object is being destroyed, notify_fn is called.
194   explicit WeakPtr(WeakRefCounted* ptr, WeakNotifyFn notify_fn = nullptr)
data_(nullptr)195       : data_(nullptr), notifier_id_(0) {
196     if (ptr != nullptr) {
197       ptr->data_->Ref();
198       data_.reset(ptr->data_.get());
199       if (notify_fn) {
200         notifier_id_ = data_->AddNotifier(notify_fn);
201       }
202     }
203   }
204 
~WeakPtr()205   ~WeakPtr() {
206     if (data_ != nullptr && notifier_id_ != 0) {
207       data_->RemoveNotifier(notifier_id_);
208     }
209   }
210 
211   // NOTE(feyu): change data_ to a IntrusivePtr to make WeakPtr copyable.
212   WeakPtr(const WeakPtr& other) = delete;
213   WeakPtr& operator=(const WeakPtr& other) = delete;
214 
WeakPtr(WeakPtr && other)215   WeakPtr(WeakPtr&& other) {
216     data_ = std::move(other.data_);
217     notifier_id_ = other.notifier_id_;
218     other.notifier_id_ = 0;
219   }
220 
221   WeakPtr& operator=(WeakPtr&& other) {
222     if (this != &other) {
223       if (data_ != nullptr && notifier_id_ != 0) {
224         data_->RemoveNotifier(notifier_id_);
225       }
226       data_ = std::move(other.data_);
227       notifier_id_ = other.notifier_id_;
228       other.notifier_id_ = 0;
229     }
230     return *this;
231   }
232 
233   // Returns a new strong reference to the referred object, or nullptr if the
234   // object is in an invalid state (being destructed or already destructed).
GetNewRef()235   RefCountPtr<T> GetNewRef() const {
236     RefCountPtr<T> ref;
237     if (data_ != nullptr) {
238       WeakRefCounted* ptr = data_->GetNewRef();
239       ref.reset(static_cast<T*>(ptr));
240     }
241     return std::move(ref);
242   }
243 
244  private:
245   RefCountPtr<WeakRefCounted::WeakRefData> data_;
246   int notifier_id_;
247 };
248 
249 // Inlined routines, since these are performance critical
RefCounted()250 inline RefCounted::RefCounted() : ref_(1) {}
251 
~RefCounted()252 inline RefCounted::~RefCounted() {
253   // A destructing object has ref_ == 0.
254   // It is a bug if the object is resurrected (ref_ > 0) before delete is
255   // called by Unref().
256   DCHECK_EQ(ref_.load(), 0);
257 }
258 
Ref()259 inline void RefCounted::Ref() const {
260   // Ref() uses relaxed order because it is never called with old_ref == 0.
261   // When old_ref >= 1, no actions depend on the new value of ref.
262   int_fast32_t old_ref = ref_.fetch_add(1, std::memory_order_relaxed);
263   DCHECK_GT(old_ref, 0);
264 }
265 
TryRef()266 inline bool RefCounted::TryRef() const {
267   // This is not on a hot path.
268   // Be conservative and use seq_cst to prevent racing with Unref() when
269   // old_ref == 0, as done in LLVM libstdc++.
270   int_fast32_t old_ref = ref_.load();
271   while (old_ref != 0) {
272     if (ref_.compare_exchange_weak(old_ref, old_ref + 1)) {
273       return true;
274     }
275   }
276   // Already destructing, cannot increase ref.
277   return false;
278 }
279 
Unref()280 inline bool RefCounted::Unref() const {
281   DCHECK_GT(ref_.load(), 0);
282   // acq_rel is used to prevent reordering introduces object access after
283   // destruction.
284 
285   // Using release alone is a bug on systems where acq_rel differs from release.
286   // (e.g. arm), according to Herb Sutter's 2012 talk on "Atomic<> Weapons".
287   if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
288     delete this;
289     return true;
290   }
291   return false;
292 }
293 
RefCount()294 inline int_fast32_t RefCounted::RefCount() const {
295   return ref_.load(std::memory_order_acquire);
296 }
297 
RefCountIsOne()298 inline bool RefCounted::RefCountIsOne() const {
299   return (ref_.load(std::memory_order_acquire) == 1);
300 }
301 
302 }  // namespace core
303 }  // namespace tensorflow
304 
305 #endif  // TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
306