#include <torch/csrc/autograd/forward_grad.h>

namespace torch::autograd {

namespace {
// See discussion in forward_grad.h for why these are global variables and not
// thread local

std::mutex all_forward_levels_mutex_;
std::vector<std::shared_ptr<ForwardADLevel>> all_forward_levels_;

const static at::Tensor singleton_undefined_tensor;
} // namespace

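// Creates a new forward AD level at the end of all_forward_levels_ and returns
// its index. Only a single (non-nested) level is currently supported.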
uint64_t ForwardADLevel::get_next_idx() {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  auto next_idx = all_forward_levels_.size();
  TORCH_CHECK(
      next_idx == 0, "Nested forward mode AD is not supported at the moment");
  all_forward_levels_.push_back(std::make_shared<ForwardADLevel>(next_idx));
  return next_idx;
}

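// Releases the most recently created level. Levels must be released in the
// reverse order of their creation.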
void ForwardADLevel::release_idx(uint64_t idx) {
  std::unique_lock<std::mutex> lock(all_forward_levels_mutex_);
  TORCH_CHECK(
      idx + 1 == all_forward_levels_.size(),
      "Exiting a forward AD level that is not the last one created is not "
      "supported. Ensure levels are released in the reverse order they were "
      "created.");
  TORCH_INTERNAL_ASSERT(!all_forward_levels_.empty());
  // Keep the level alive until we have released the lock
  auto lvl = all_forward_levels_.back();
  all_forward_levels_.pop_back();
  lock.unlock();
}

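// Returns the level stored at the given index, throwing if the index was never
// created or has already been deleted.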
std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  TORCH_CHECK(
      idx < all_forward_levels_.size(),
      "Trying to access a forward AD level with an invalid index. "
      "This index was either not created or is already deleted.");
  return all_forward_levels_[idx];
}

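// Non-throwing variant of get_by_idx: returns nullptr if no level exists at
// the given index.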
std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) {
  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
  if (idx < all_forward_levels_.size()) {
    return all_forward_levels_[idx];
  } else {
    return nullptr;
  }
}

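// On destruction, reset every ForwardGrad registered at this level so that its
// tangent for this level is cleared.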
ForwardADLevel::~ForwardADLevel() {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = grads_.begin();
  while (it != grads_.end()) {
    // Warning: this will lock (*it)'s mutex.
    // This is ok as this function is the *only* one to call back into another
    // class's method.
    (*it)->reset(idx_, /* update_level */ false);
    it = grads_.erase(it);
  }
}

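// Returns the tangent stored for the given level, or the shared undefined
// tensor if no tangent is set at that level.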
const at::Tensor& ForwardGrad::value(uint64_t level) const {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto& it = content_.find(level);
  return it == content_.end() ? singleton_undefined_tensor : (*it).second;
}

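// Returns the singleton undefined tensor used to represent a missing tangent.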
const at::Tensor& ForwardGrad::undef_grad() {
  return singleton_undefined_tensor;
}

} // namespace torch::autograd