#pragma once

#include <ATen/core/Tensor.h>
#include <c10/util/SmallVector.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>

namespace torch::autograd {

// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner
// design. But this shared_ptr must be uniquely associated with the object that
// stores it (as of writing, either AutogradMeta or SavedVariable). This object
// is called the "owning object" in the discussions below. This owning object
// must call `ForwardGrad::clear()` when it is destroyed to ensure that the
// ForwardGrad is properly de-allocated.

struct ForwardGrad;

// This file contains two classes that are used to store forward AD gradients
// and ensure that they are scoped properly.
// Because forward AD runs concurrently with the evaluation of the function, we
// need a mechanism to separate different forward AD invocations and be able to
// compute the right gradients. We model such invocations as levels here.
// The particular scoping issue mentioned above has two main drivers:
// - Ensure that we can conveniently use forward AD within a high level API
//   without leaking the forward AD state outside.
// - Ensure that we can keep the level that we expose to the user API simple
//   (an integer that represents the nesting depth) while avoiding confusion
//   when a level index is re-used.

// The important external APIs from this file are (see the usage sketch below):
// - ForwardADLevel::get_next_idx() that can be used to enter a new level and
//   get its index.
// - ForwardADLevel::release_idx() that can be used to exit a given level.
// - ForwardGrad() that can be used to store a given forward gradient and
//   handles the level tracking automatically.
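// The sketch below shows how these APIs are intended to fit together. It is a
// simplified, hypothetical example: `tangent` is a placeholder Tensor and the
// explicit ForwardGrad here stands in for the one normally owned by a tensor's
// AutogradMeta:
//
//   // Enter a new forward AD level and remember its index.
//   uint64_t idx = ForwardADLevel::get_next_idx();
//
//   // Attach a tangent for that level. This registers the ForwardGrad with
//   // the corresponding ForwardADLevel.
//   auto fw_grad = std::make_shared<ForwardGrad>();
//   fw_grad->set_value(tangent, idx);
//
//   // ... run the computation that propagates tangents at level `idx` ...
//
//   // Exit the level. Destroying the ForwardADLevel clears the tangents
//   // stored for that level in every ForwardGrad still registered with it.
//   ForwardADLevel::release_idx(idx);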
// The basic implementation strategy is as follows:
// Every tensor has a ForwardGrad, maintaining a map from levels to tangents.
// ForwardGrad is responsible for registering itself with the appropriate
// ForwardADLevel when a new tangent is added to it via ForwardGrad::set_value
// and for un-registering itself from this same level if that tangent is
// removed via ForwardGrad::reset.
// The ForwardADLevel is created when a new level is entered via
// ForwardADLevel::get_next_idx. A reference to the new ForwardADLevel is
// stored into a global (for the whole process) vector that ensures it can be
// accessed via ForwardADLevel::get_by_idx. This reference is deleted when the
// index is released by the user via ForwardADLevel::release_idx. When it is
// destructed, the ForwardADLevel is responsible for clearing all the tangents
// for its level stored in all the ForwardGrad that registered with it.
//
// This process-wide level design, compared to a thread local one, allows us to
// use a very simple user-facing handle for the level (an int) while enabling
// cross-thread forward AD. The only synchronization required from the user is
// when entering and exiting levels. Some discussion on alternative designs is
// in https://github.com/pytorch/pytorch/pull/49097#discussion_r543716453 and
// can be refined in the future.

// Correctness of concurrency:
// Each class uses its own lock when reading or modifying its internal storage.
// This allows, in particular, safely removing tangents from ForwardGrad when
// the ForwardADLevel is being exited.
// We ensure no deadlock by ensuring that a method never calls into another
// class's method while the local class's lock is held, except in one single
// case: calling from ForwardADLevel's destructor into ForwardGrad::reset with
// update_level=false.

// The lifetime of these objects is as follows:
// The ForwardADLevel can be in three states:
// - Initialized: one of its references is held by the global vector and there
//   may be more references held by temporary variables in ForwardGrad's
//   methods.
// - About to be destructed: "release_idx" has been called and the only reason
//   for the ForwardADLevel not to be destructed right away is that some
//   methods in ForwardGrad hold an owning reference to it. This is done so
//   that a ForwardADLevel can never be destructed while a ForwardGrad that is
//   registered with it is in the process of adding something to its internal
//   state.
// - Being destructed: the ForwardADLevel is not referenced anymore and can
//   safely reset all of the ForwardGrad that registered with it. Note that we
//   can have more than one reset being called here (which is ok) but we are
//   guaranteed that there is at least one.
// The ForwardGrad is simpler as there is no intermediary state and no special
// destructor for it. The logic to unregister it from the different
// ForwardADLevel is done when the owning object (AutogradMeta or
// SavedVariable) is being destroyed.

// Other considered design:
// To avoid having ForwardGrad::clear, we considered storing weak_ptr inside
// the ForwardADLevel. While this would work, it would mean that the set inside
// the ForwardADLevel would only grow unless we did an expensive linear scan to
// remove all the dangling weak pointers. Hence this approach was not used.

// Data structures in this file are optimized for this maximum number of
// levels. The number of levels corresponds to the degree of the gradient being
// computed using forward AD and we don't expect more than second order
// gradients to be common.
#define EXPECTED_MAX_LEVEL 2
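// As a hedged illustration of that nesting (a second order gradient computed
// by entering two levels; the exact index values are an assumption here, but
// the index is documented above as the nesting depth):
//
//   uint64_t outer = ForwardADLevel::get_next_idx(); // e.g. 0
//   uint64_t inner = ForwardADLevel::get_next_idx(); // e.g. 1
//   // ... compute with tangents attached at both levels ...
//   ForwardADLevel::release_idx(inner); // exit the innermost level first
//   ForwardADLevel::release_idx(outer);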
struct TORCH_API ForwardADLevel {
  ForwardADLevel(uint64_t idx) : idx_(idx) {}
  ~ForwardADLevel();

  static uint64_t get_next_idx();
  static void release_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

  void erase(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.erase(grad);
  }

  void insert(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.insert(grad);
  }

 private:
  std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
  std::mutex mutex_;
  uint64_t idx_;
};

struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
  ForwardGrad() = default;

  // This function must only be called when AutogradMeta or SavedVariable is
  // being destructed as it ensures that:
  // - The only (potential) other references to this ForwardGrad are the
  //   different levels it is registered with.
  // - No other thread will try to call `set_value` or `value` ever from now
  //   on.
  // - Any of the ForwardADLevel that this ForwardGrad is registered with might
  //   call `reset` at any point during this function.
  void clear() {
    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;

    {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& c : content_) {
        levels_idx.push_back(c.first);
      }
    }

    for (auto l_idx : levels_idx) {
      // Use the "try" version here as another thread might have deleted this
      // level before we got here.
      // This is an owning reference as we want to keep the level alive
      // until we successfully unregister ourselves.
      auto level = ForwardADLevel::try_get_by_idx(l_idx);
      if (level) {
        level->erase(shared_from_this());
      }
    }
  }
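  // A minimal sketch of the owning-object contract described above. `MyOwner`
  // is a hypothetical stand-in for AutogradMeta/SavedVariable, not a class
  // that exists in this codebase:
  //
  //   struct MyOwner {
  //     std::shared_ptr<ForwardGrad> fw_grad_ =
  //         std::make_shared<ForwardGrad>();
  //     ~MyOwner() {
  //       // Required so that the ForwardGrad un-registers itself from every
  //       // level it is still registered with and can be de-allocated.
  //       if (fw_grad_) {
  //         fw_grad_->clear();
  //       }
  //     }
  //   };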
  void set_value(const at::Tensor& value, uint64_t level) {
    // Owning reference to ensure the forward_level is not destroyed
    // while we are updating our internal state
    auto forward_level = ForwardADLevel::get_by_idx(level);
    forward_level->insert(shared_from_this());

    std::lock_guard<std::mutex> lock(mutex_);
    content_.insert({level, value});
  }

  // This function removes the tangent for a given level from this ForwardGrad.
  // Use the update_level flag to disable notifying the level about this reset.
  // This flag is most notably used by the ForwardADLevel destructor.
  void reset(uint64_t level, bool update_level = true) {
    if (update_level) {
      ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
    }

    std::unique_lock<std::mutex> lock(mutex_);
    const auto& it = content_.find(level);
    TORCH_INTERNAL_ASSERT(
        it != content_.end(), "Resetting a non-existent level.");
    // Keep the Tensor alive until we have released the lock.
    // This is needed as we can be in a case where this function is called by
    // the ForwardADLevel destructor.
    auto t = (*it).second;
    content_.erase(level);
    lock.unlock();
  }

  const at::Tensor& value(uint64_t level) const;

  bool contains(uint64_t level) {
    std::lock_guard<std::mutex> lock(mutex_);
    return content_.count(level) > 0;
  }

  bool empty() const {
    return content_.empty();
  }

  static const at::Tensor& undef_grad();

 private:
  // TODO(albanD): replace this with a SmallVector
  std::unordered_map<uint64_t, at::Tensor> content_;
  mutable std::mutex mutex_;
};

} // namespace torch::autograd