1 #include <torch/csrc/autograd/forward_grad.h> 2 3 namespace torch::autograd { 4 5 namespace { 6 // See discussion in forward_grad.h for why these are global variables and not 7 // thread local 8 9 std::mutex all_forward_levels_mutex_; 10 std::vector<std::shared_ptr<ForwardADLevel>> all_forward_levels_; 11 12 const static at::Tensor singleton_undefined_tensor; 13 } // namespace 14 get_next_idx()15uint64_t ForwardADLevel::get_next_idx() { 16 std::lock_guard<std::mutex> lock(all_forward_levels_mutex_); 17 auto next_idx = all_forward_levels_.size(); 18 TORCH_CHECK( 19 next_idx == 0, "Nested forward mode AD is not supported at the moment"); 20 all_forward_levels_.push_back(std::make_shared<ForwardADLevel>(next_idx)); 21 return next_idx; 22 } 23 release_idx(uint64_t idx)24void ForwardADLevel::release_idx(uint64_t idx) { 25 std::unique_lock<std::mutex> lock(all_forward_levels_mutex_); 26 TORCH_CHECK( 27 idx + 1 == all_forward_levels_.size(), 28 "Exiting a forward AD level that is not the " 29 "last that was created is not support. Ensure they are released in the reverse " 30 "order they were created."); 31 TORCH_INTERNAL_ASSERT(!all_forward_levels_.empty()); 32 // Keep the level alive until we have released the lock 33 auto lvl = all_forward_levels_.back(); 34 all_forward_levels_.pop_back(); 35 lock.unlock(); 36 } 37 get_by_idx(uint64_t idx)38std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) { 39 std::lock_guard<std::mutex> lock(all_forward_levels_mutex_); 40 TORCH_CHECK( 41 idx < all_forward_levels_.size(), 42 "Trying to access a forward AD level with an invalid index. " 43 "This index was either not created or is already deleted."); 44 return all_forward_levels_[idx]; 45 } 46 try_get_by_idx(uint64_t idx)47std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) { 48 std::lock_guard<std::mutex> lock(all_forward_levels_mutex_); 49 if (idx < all_forward_levels_.size()) { 50 return all_forward_levels_[idx]; 51 } else { 52 return nullptr; 53 } 54 } 55 ~ForwardADLevel()56ForwardADLevel::~ForwardADLevel() { 57 std::lock_guard<std::mutex> lock(mutex_); 58 auto it = grads_.begin(); 59 while (it != grads_.end()) { 60 // Warning this will lock *it mutex 61 // This is ok as this function is the *only* one to call back into another 62 // class's method. 63 (*it)->reset(idx_, /* update_level */ false); 64 it = grads_.erase(it); 65 } 66 } 67 value(uint64_t level) const68const at::Tensor& ForwardGrad::value(uint64_t level) const { 69 std::lock_guard<std::mutex> lock(mutex_); 70 const auto& it = content_.find(level); 71 return it == content_.end() ? singleton_undefined_tensor : (*it).second; 72 } 73 undef_grad()74const at::Tensor& ForwardGrad::undef_grad() { 75 return singleton_undefined_tensor; 76 } 77 78 } // namespace torch::autograd 79