xref: /aosp_15_r20/external/pytorch/c10/util/Lazy.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <atomic>
4 #include <utility>
5 
6 namespace c10 {
7 
/**
 * Thread-safe lazy value with opportunistic concurrency: on concurrent first
 * access, the factory may be called by multiple threads, but only one result is
 * stored and its reference returned to all the callers.
 *
 * The value is heap-allocated; this optimizes for the case in which the value
 * is never actually computed.
 */
template <class T>
class OptimisticLazy {
 public:
  OptimisticLazy() = default;

  // Deep-copies `other`'s value if it has already been computed; otherwise
  // the copy starts empty.
  OptimisticLazy(const OptimisticLazy& other) : value_(cloneValue(other)) {}

  // Takes ownership of `other`'s value (if any), leaving `other` empty.
  OptimisticLazy(OptimisticLazy&& other) noexcept
      : value_(other.value_.exchange(nullptr, std::memory_order_acq_rel)) {}

  ~OptimisticLazy() {
    reset();
  }

  /**
   * Returns the stored value, computing it with `factory` on first access.
   *
   * @param factory Callable returning something T can be constructed from.
   *   It is invoked (perfectly forwarded) only when no value is stored yet.
   * @return Reference to the single stored value. The reference stays valid
   *   until reset(), assignment, move-from, or destruction of this object.
   *
   * If several threads race on first access, each may invoke its factory, but
   * only one result is published; the others are destroyed and every caller
   * receives a reference to the published value. If `factory` or T's
   * constructor throws, nothing is stored and the exception propagates.
   */
  template <class Factory>
  T& ensure(Factory&& factory) {
    if (T* value = value_.load(std::memory_order_acquire)) {
      return *value;
    }
    T* value = new T(std::forward<Factory>(factory)());
    T* old = nullptr;
    // Release on success publishes our fully constructed T to later acquire
    // loads; acquire on failure makes the winner's T visible before we return
    // a reference to it.
    if (!value_.compare_exchange_strong(
            old, value, std::memory_order_release, std::memory_order_acquire)) {
      // Lost the race: discard our result and use the published one.
      delete value;
      value = old;
    }
    return *value;
  }

  // The following methods are not thread-safe: they should not be called
  // concurrently with any other method.

  // Copy-then-move: if copying `other` throws, *this is left unchanged.
  OptimisticLazy& operator=(const OptimisticLazy& other) {
    *this = OptimisticLazy{other};
    return *this;
  }

  // Destroys this object's value (if any) and takes ownership of `other`'s,
  // leaving `other` empty. Self-move-assignment is a no-op.
  OptimisticLazy& operator=(OptimisticLazy&& other) noexcept {
    if (this != &other) {
      reset();
      value_.store(
          other.value_.exchange(nullptr, std::memory_order_acquire),
          std::memory_order_release);
    }
    return *this;
  }

  // Destroys the stored value, if any, leaving this object empty so the next
  // ensure() recomputes it.
  void reset() {
    if (T* old = value_.load(std::memory_order_relaxed)) {
      value_.store(nullptr, std::memory_order_relaxed);
      delete old;
    }
  }

 private:
  // Returns a heap-allocated copy of `source`'s value, or nullptr if it has
  // none. The acquire load pairs with the release CAS in ensure().
  static T* cloneValue(const OptimisticLazy& source) {
    T* value = source.value_.load(std::memory_order_acquire);
    return value != nullptr ? new T(*value) : nullptr;
  }

  // Owning pointer to the computed value; nullptr while not yet computed.
  std::atomic<T*> value_{nullptr};
};
74 
/**
 * Interface for a value that is computed on first access.
 *
 * get() hands out a const reference to the value; when (and whether) the
 * value is actually computed is left to each implementation.
 */
template <typename T>
class LazyValue {
 public:
  // Polymorphic base: virtual so implementations can be destroyed through a
  // LazyValue<T> pointer.
  virtual ~LazyValue() = default;

  // Returns the (possibly lazily computed) value.
  virtual T const& get() const = 0;
};
85 
86 /**
87  * Convenience thread-safe LazyValue implementation with opportunistic
88  * concurrency.
89  */
90 template <class T>
91 class OptimisticLazyValue : public LazyValue<T> {
92  public:
get()93   const T& get() const override {
94     return value_.ensure([this] { return compute(); });
95   }
96 
97  private:
98   virtual T compute() const = 0;
99 
100   mutable OptimisticLazy<T> value_;
101 };
102 
103 /**
104  * Convenience immutable (thus thread-safe) LazyValue implementation for cases
105  * in which the value is not actually lazy.
106  */
107 template <class T>
108 class PrecomputedLazyValue : public LazyValue<T> {
109  public:
PrecomputedLazyValue(T value)110   PrecomputedLazyValue(T value) : value_(std::move(value)) {}
111 
get()112   const T& get() const override {
113     return value_;
114   }
115 
116  private:
117   T value_;
118 };
119 
120 } // namespace c10
121