#pragma once

#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

namespace torch::autograd {

// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner
// design. But this shared_ptr must be uniquely associated with the object that
// stores it (as of writing, either AutogradMeta or SavedVariable). This object
// is called the "owning object" in the discussions below. This owning object
// must call `ForwardGrad::clear()` when it is destroyed to ensure that the
// ForwardGrad is properly de-allocated.

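// A minimal sketch of that contract, assuming a hypothetical owning object
// (standing in for AutogradMeta / SavedVariable); illustrative only, not part
// of this header:
//
//   struct MyOwner {
//     std::shared_ptr<ForwardGrad> fw_grad_ = std::make_shared<ForwardGrad>();
//     ~MyOwner() {
//       // Unregisters this ForwardGrad from every ForwardADLevel it is
//       // registered with, so no level keeps it (or its tangents) alive.
//       fw_grad_->clear();
//     }
//   };
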
struct ForwardGrad;

// This file contains two classes that are used to store forward AD gradients
// and ensure that they are scoped properly. Because forward AD runs
// concurrently with the evaluation of the function, we need a mechanism to
// separate different forward AD invocations and be able to compute the right
// gradients. We model such invocations as levels here. The particular scoping
// issue mentioned above has two main drivers:
//   - Ensure that we can conveniently use forward AD within a high-level API
//     without leaking the forward AD state outside.
//   - Ensure that we can keep the level that we expose to the user API simple
//     (an integer that represents the nesting depth) while avoiding confusion
//     when a level index is re-used.

// The important external APIs from this file are:
//   - ForwardADLevel::get_next_idx() that can be used to enter a new level and
//     get its index.
//   - ForwardADLevel::release_idx() that can be used to exit a given level.
//   - ForwardGrad, which can be used to store a given forward gradient and
//     handles the level tracking automatically.

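// A minimal end-to-end sketch of these APIs (illustrative only; `tangent` is a
// hypothetical at::Tensor and, in practice, the Python frontend in
// torch.autograd.forward_ad drives these calls):
//
//   // Enter a new forward AD level and get its index.
//   uint64_t idx = ForwardADLevel::get_next_idx();
//
//   // Store a tangent for that level; this registers the ForwardGrad with
//   // the corresponding ForwardADLevel.
//   auto fw_grad = std::make_shared<ForwardGrad>();
//   fw_grad->set_value(tangent, idx);
//
//   // Read the tangent back while the level is alive.
//   const at::Tensor& t = fw_grad->value(idx);
//
//   // Exit the level: once the ForwardADLevel is destructed, it clears the
//   // tangents stored for `idx` in every ForwardGrad registered with it.
//   ForwardADLevel::release_idx(idx);
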
// The basic implementation strategy is as follows:
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
// ForwardGrad is responsible for registering itself with the appropriate
// ForwardADLevel when a new tangent is added to it via ForwardGrad::set_value
// and for un-registering itself from that same level if the tangent is removed
// via ForwardGrad::reset. The ForwardADLevel is created when a new level is
// entered via ForwardADLevel::get_next_idx. A reference to the new
// ForwardADLevel is stored in a global (for the whole process) vector that
// ensures it can be accessed via ForwardADLevel::get_by_idx. This reference is
// deleted when the index is released by the user via
// ForwardADLevel::release_idx. When it is destructed, the ForwardADLevel is
// responsible for clearing all the tangents for its level stored in all the
// ForwardGrad objects that registered with it.
//
// This process-wide level design, compared to a thread-local one, allows us to
// use a very simple user-facing handle for the level (an int) while enabling
// cross-thread forward AD. The only synchronization required from the user is
// when entering and exiting the levels. Some discussion of alternative designs
// can be found in
// https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 and the
// design can be refined in the future.

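// Concretely, the process-wide state described above amounts to a registry of
// currently alive levels. The real definitions live in forward_grad.cpp; the
// following is only a sketch of the assumed shape (names are hypothetical):
//
//   // Guards the registry of levels for the whole process.
//   std::mutex all_levels_mutex;
//   // Slot i holds an owning reference to the level with index i.
//   std::vector<std::shared_ptr<ForwardADLevel>> all_levels;
//
// get_next_idx() appends a new ForwardADLevel and returns its position,
// get_by_idx() reads a slot under the mutex, and release_idx() drops the
// registry's reference so the level is destructed once no ForwardGrad method
// still holds a temporary reference to it.
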
// Correctness of concurrency:
// Each class uses its own lock when reading or modifying its internal storage.
// In particular, this allows tangents to be safely removed from a ForwardGrad
// when the ForwardADLevel is being exited. We ensure there is no deadlock by
// ensuring that a method never calls into another class's method while the
// local class's lock is held, except in a single case: calling from
// ForwardADLevel's destructor into ForwardGrad::reset with update_level=false.

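// To make that single exception concrete, the destructor (defined in
// forward_grad.cpp) is expected to look roughly like the following sketch,
// not a verbatim copy of the implementation:
//
//   ForwardADLevel::~ForwardADLevel() {
//     std::lock_guard<std::mutex> lock(mutex_);
//     for (const auto& grad : grads_) {
//       // The only place where we call into another class's method while
//       // holding our own lock. update_level=false prevents reset() from
//       // calling back into this (dying) level.
//       grad->reset(idx_, /* update_level */ false);
//     }
//   }
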
// The lifetime of these objects is as follows:
// The ForwardADLevel can be in three states:
//      - Initialized: one of its references is held by the global vector and
//        there may be more references held by temporary variables in
//        ForwardGrad's methods.
//      - About to be destructed: "release_idx" has been called and the only
//        reason for the ForwardADLevel not to be destructed right away is that
//        some methods in ForwardGrad hold an owning reference to it. This is
//        done so that a ForwardADLevel can never be destructed while a
//        ForwardGrad is registered with it and in the process of adding
//        something to its internal state.
//      - Being destructed: here the ForwardADLevel is not referenced anymore
//        and can safely reset all of the ForwardGrads registered with it. Note
//        that more than one reset can be called from here (which is ok), but we
//        are guaranteed that there is at least one.
// The ForwardGrad is simpler as there is no intermediate state and no special
// destructor for it. The logic to unregister it from the different
// ForwardADLevels runs when the owning object (AutogradMeta or SavedVariable)
// is being destroyed.

// Other design considered:
// To avoid needing ForwardGrad::clear, we considered storing weak_ptrs inside
// the ForwardADLevel. While this would work, it would mean that the set inside
// the ForwardADLevel would only grow unless we did an expensive linear scan to
// remove all the dangling weak pointers. Hence this approach was not used.

// Data structures in this file are optimized for the expected maximum number
// of levels defined below. The number of levels corresponds to the order of
// the gradient being computed using forward AD, and we don't expect anything
// beyond second-order gradients to be common.
#define EXPECTED_MAX_LEVEL 2
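
// For example, stack-allocated buffers that hold one entry per level can be
// sized with this constant so the common (first- or second-order) case never
// heap-allocates, as ForwardGrad::clear does below:
//
//   c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;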

struct TORCH_API ForwardADLevel {
  ForwardADLevel(uint64_t idx) : idx_(idx) {}
  ~ForwardADLevel();

  static uint64_t get_next_idx();
  static void release_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

  // Un-register a ForwardGrad from this level (called from ForwardGrad::reset
  // and ForwardGrad::clear).
  void erase(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.erase(grad);
  }

  // Register a ForwardGrad with this level (called from
  // ForwardGrad::set_value).
  void insert(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.insert(grad);
  }

 private:
  std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
  std::mutex mutex_;
  uint64_t idx_;
};

struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
  ForwardGrad() = default;

  // This function must only be called when AutogradMeta or SavedVariable is
  // being destructed, as that ensures that:
  //   - The only (potential) other references to this ForwardGrad are the
  //     different levels it is registered with.
  //   - No other thread will try to call `set_value` or `value` ever from now
  //     on.
  //   - Any of the ForwardADLevels that this ForwardGrad is registered with
  //     might call `reset` at any point during this function.
  void clear() {
    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;

    {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& c : content_) {
        levels_idx.push_back(c.first);
      }
    }

    for (auto l_idx : levels_idx) {
      // Use "try" version here as another thread might have deleted this
      // level before we got here
      // This is an owning reference as we want to keep the level alive
      // until we successfully unregister ourselves
      auto level = ForwardADLevel::try_get_by_idx(l_idx);
      if (level) {
        level->erase(shared_from_this());
      }
    }
  }

  void set_value(const at::Tensor& value, uint64_t level) {
    // Owning reference to ensure the forward_level is not destroyed
    // while we are updating our internal state
    auto forward_level = ForwardADLevel::get_by_idx(level);
    forward_level->insert(shared_from_this());

    std::lock_guard<std::mutex> lock(mutex_);
    content_.insert({level, value});
  }

  // This function removes the tangent for a given level from this ForwardGrad.
  // Use the update_level flag to disable notifying the level about this reset.
  // This flag is most notably used by the ForwardADLevel destructor.
  void reset(uint64_t level, bool update_level = true) {
    if (update_level) {
      ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
    }

    std::unique_lock<std::mutex> lock(mutex_);
    const auto& it = content_.find(level);
    TORCH_INTERNAL_ASSERT(
        it != content_.end(), "Resetting a non-existent level.");
    // Keep the Tensor alive until we have released the lock.
    // This is needed as we can be in a case where this function is called by
    // the ForwardADLevel destructor.
    auto t = (*it).second;
    content_.erase(level);
    lock.unlock();
  }

  // Returns the tangent stored for the given level (implemented in
  // forward_grad.cpp).
  const at::Tensor& value(uint64_t level) const;

  bool contains(uint64_t level) {
    std::lock_guard<std::mutex> lock(mutex_);
    return content_.count(level) > 0;
  }

  bool empty() const {
    return content_.empty();
  }

  // Returns a reference to an undefined Tensor used when no tangent is
  // available (implemented in forward_grad.cpp).
  static const at::Tensor& undef_grad();

 private:
  // TODO(albanD): replace this with a SmallVector
  std::unordered_map<uint64_t, at::Tensor> content_;
  mutable std::mutex mutex_;
};

} // namespace torch::autograd