xref: /aosp_15_r20/external/pytorch/c10/util/Lazy.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <atomic>
4*da0073e9SAndroid Build Coastguard Worker #include <utility>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10 {
7*da0073e9SAndroid Build Coastguard Worker 
/**
 * Thread-safe lazy value with optimistic (opportunistic) concurrency: on
 * concurrent first access, the factory may be invoked by several threads, but
 * only one result is published and a reference to it is returned to every
 * caller; the results of the losing threads are deleted.
 *
 * The value is heap-allocated, so an instance whose value is never computed
 * only stores a single (null) pointer.
 */
template <class T>
class OptimisticLazy {
 public:
  OptimisticLazy() = default;

  /// Deep-copies `other`'s value if it has already been computed; otherwise
  /// this instance starts out empty.
  OptimisticLazy(const OptimisticLazy& other) {
    if (T* value = other.value_.load(std::memory_order_acquire)) {
      value_ = new T(*value);
    }
  }

  /// Takes ownership of `other`'s value (if any), leaving `other` empty.
  OptimisticLazy(OptimisticLazy&& other) noexcept
      : value_(other.value_.exchange(nullptr, std::memory_order_acq_rel)) {}

  ~OptimisticLazy() {
    reset();
  }

  /**
   * Returns the stored value, first computing it with `factory()` if no value
   * has been published yet.
   *
   * May be called concurrently with other ensure() calls. If several threads
   * race on the first access, each may invoke its factory; only the first
   * result to be published is kept and the others are deleted. If `factory`
   * (or T's constructor) throws, nothing is stored and the exception
   * propagates.
   *
   * @param factory Callable returning a value T can be constructed from. It is
   *        perfectly forwarded, so rvalue-ref-qualified call operators work.
   * @return Reference to the published value. The referenced object is deleted
   *         when its owning OptimisticLazy is reset(), assigned to, or
   *         destroyed.
   */
  template <class Factory>
  T& ensure(Factory&& factory) {
    // Fast path: a value is already published. The acquire load pairs with the
    // release in the compare-exchange below, so the pointee is fully visible.
    if (T* value = value_.load(std::memory_order_acquire)) {
      return *value;
    }
    T* value = new T(std::forward<Factory>(factory)());
    T* old = nullptr;
    // Publish our value only if no other thread beat us to it. On success,
    // release makes the constructed T visible to acquiring readers; on
    // failure, acquire makes the winner's T safe to dereference.
    if (!value_.compare_exchange_strong(
            old, value, std::memory_order_release, std::memory_order_acquire)) {
      delete value;
      value = old;
    }
    return *value;
  }

  // The following methods are not thread-safe: they should not be called
  // concurrently with any other method.

  /// Copy-assigns via copy-and-move; self-assignment is safe.
  OptimisticLazy& operator=(const OptimisticLazy& other) {
    *this = OptimisticLazy{other};
    return *this;
  }

  /// Drops the current value and takes ownership of `other`'s value (if any),
  /// leaving `other` empty.
  OptimisticLazy& operator=(OptimisticLazy&& other) noexcept {
    if (this != &other) {
      reset();
      value_.store(
          other.value_.exchange(nullptr, std::memory_order_acquire),
          std::memory_order_release);
    }
    return *this;
  }

  /// Deletes the stored value, if any, so the next ensure() recomputes it.
  void reset() {
    if (T* old = value_.load(std::memory_order_relaxed)) {
      value_.store(nullptr, std::memory_order_relaxed);
      delete old;
    }
  }

 private:
  // Owning pointer to the published value; nullptr until computed.
  std::atomic<T*> value_{nullptr};
};
74*da0073e9SAndroid Build Coastguard Worker 
/**
 * Abstract interface for a value of type T that is produced on first access.
 *
 * Implementations decide how the value is obtained; callers only ever see it
 * through the const reference returned by get().
 */
template <class T>
class LazyValue {
 public:
  virtual ~LazyValue() = default;

  /// Returns the value, producing it first if the implementation has not yet
  /// done so.
  virtual auto get() const -> const T& = 0;
};
85*da0073e9SAndroid Build Coastguard Worker 
86*da0073e9SAndroid Build Coastguard Worker /**
87*da0073e9SAndroid Build Coastguard Worker  * Convenience thread-safe LazyValue implementation with opportunistic
88*da0073e9SAndroid Build Coastguard Worker  * concurrency.
89*da0073e9SAndroid Build Coastguard Worker  */
90*da0073e9SAndroid Build Coastguard Worker template <class T>
91*da0073e9SAndroid Build Coastguard Worker class OptimisticLazyValue : public LazyValue<T> {
92*da0073e9SAndroid Build Coastguard Worker  public:
get()93*da0073e9SAndroid Build Coastguard Worker   const T& get() const override {
94*da0073e9SAndroid Build Coastguard Worker     return value_.ensure([this] { return compute(); });
95*da0073e9SAndroid Build Coastguard Worker   }
96*da0073e9SAndroid Build Coastguard Worker 
97*da0073e9SAndroid Build Coastguard Worker  private:
98*da0073e9SAndroid Build Coastguard Worker   virtual T compute() const = 0;
99*da0073e9SAndroid Build Coastguard Worker 
100*da0073e9SAndroid Build Coastguard Worker   mutable OptimisticLazy<T> value_;
101*da0073e9SAndroid Build Coastguard Worker };
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker /**
104*da0073e9SAndroid Build Coastguard Worker  * Convenience immutable (thus thread-safe) LazyValue implementation for cases
105*da0073e9SAndroid Build Coastguard Worker  * in which the value is not actually lazy.
106*da0073e9SAndroid Build Coastguard Worker  */
107*da0073e9SAndroid Build Coastguard Worker template <class T>
108*da0073e9SAndroid Build Coastguard Worker class PrecomputedLazyValue : public LazyValue<T> {
109*da0073e9SAndroid Build Coastguard Worker  public:
PrecomputedLazyValue(T value)110*da0073e9SAndroid Build Coastguard Worker   PrecomputedLazyValue(T value) : value_(std::move(value)) {}
111*da0073e9SAndroid Build Coastguard Worker 
get()112*da0073e9SAndroid Build Coastguard Worker   const T& get() const override {
113*da0073e9SAndroid Build Coastguard Worker     return value_;
114*da0073e9SAndroid Build Coastguard Worker   }
115*da0073e9SAndroid Build Coastguard Worker 
116*da0073e9SAndroid Build Coastguard Worker  private:
117*da0073e9SAndroid Build Coastguard Worker   T value_;
118*da0073e9SAndroid Build Coastguard Worker };
119*da0073e9SAndroid Build Coastguard Worker 
120*da0073e9SAndroid Build Coastguard Worker } // namespace c10
121