xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/anomaly_mode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Backtrace.h>
2 #include <c10/util/Exception.h>
3 #include <torch/csrc/autograd/anomaly_mode.h>
4 #include <torch/csrc/autograd/function.h>
5 #include <mutex>
6 
7 namespace torch::autograd {
8 
// Process-wide flag: true while at least one DetectAnomalyGuard is alive
// (see set_enabled calls below).
bool AnomalyMode::_enabled = false;
// Whether NaN checking is requested while anomaly mode is enabled;
// defaults to on.
bool AnomalyMode::_check_nan = true;
11 
namespace {
// Mutex serializing updates to the guard counter and the global
// AnomalyMode flags. Meyers singleton: constructed on first use.
std::mutex& get_anomaly_guard_lock() {
  static std::mutex guard_mutex{};
  return guard_mutex;
}

// Number of DetectAnomalyGuard instances currently alive; guarded by
// get_anomaly_guard_lock().
uint32_t& get_anomaly_counter() {
  static uint32_t live_guard_count = 0;
  return live_guard_count;
}
} // namespace
23 
DetectAnomalyGuard(bool check_nan)24 DetectAnomalyGuard::DetectAnomalyGuard(bool check_nan) {
25   TORCH_WARN_ONCE(
26       "This mode should be enabled only for debugging as the different tests will slow down your program execution.");
27   std::lock_guard<std::mutex> lock(get_anomaly_guard_lock());
28   uint32_t& counter = get_anomaly_counter();
29   counter++;
30   this->prev_check_nan_ = AnomalyMode::should_check_nan();
31   AnomalyMode::set_enabled(true, check_nan);
32 }
33 
~DetectAnomalyGuard()34 DetectAnomalyGuard::~DetectAnomalyGuard() {
35   std::lock_guard<std::mutex> lock(get_anomaly_guard_lock());
36   uint32_t& counter = get_anomaly_counter();
37   counter--;
38   AnomalyMode::set_enabled(counter > 0, this->prev_check_nan_);
39 }
40 
41 AnomalyMetadata::~AnomalyMetadata() = default;
42 
// Captures the current C++ backtrace into traceback_ so print_stack() can
// report where the forward call happened. Skips 1 frame so the trace starts
// at this function's caller rather than at store_stack() itself.
void AnomalyMetadata::store_stack() {
  traceback_ = c10::get_backtrace(/* frames_to_skip */ 1);
}
46 
print_stack(const std::string & current_node_name)47 void AnomalyMetadata::print_stack(const std::string& current_node_name) {
48   TORCH_WARN(
49       "Error detected in ",
50       current_node_name,
51       ". ",
52       "Traceback of forward call that caused the error:\n",
53       traceback_);
54 
55   auto& cur_parent = parent_;
56   // if there is no "parent_" in metadata, then it means this metadata's node
57   // is the root and stop printing the traceback
58   while (cur_parent) {
59     auto parent_metadata = cur_parent->metadata();
60     TORCH_WARN(
61         "\n\n",
62         "Previous calculation was induced by ",
63         cur_parent->name(),
64         ". "
65         "Traceback of forward call that induced the previous calculation:\n",
66         parent_metadata->traceback_);
67     // get the parent of this node, if this node is a root, pyparent is simply
68     // null
69     cur_parent = parent_metadata->parent_;
70   }
71 }
72 
// Records the node whose forward computation induced this node's, forming
// the chain that print_stack() walks when an error is detected.
void AnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  parent_ = parent_node;
}
76 
77 } // namespace torch::autograd
78