#pragma once

#include <torch/csrc/Export.h>
#include <memory>
#include <string>

namespace torch::autograd {

// forward declaration of Node from function.h
struct Node;

struct TORCH_API AnomalyMode {
  static bool is_enabled() {
    return _enabled;
  }
  static bool should_check_nan() {
    return _check_nan;
  }
  static void set_enabled(bool enabled, bool check_nan = true) {
    _enabled = enabled;
    _check_nan = check_nan;
  }

 private:
  static bool _enabled;
  static bool _check_nan;
};

/// A RAII guard that enables Anomaly Detection Mode.
///
/// Anomaly detection mode is useful for debugging problems that happen
/// in the backward pass, such as unexpectedly modified tensors or NaNs
/// occurring in the backward pass.
///
/// Enabling anomaly mode is global: as soon as one such guard exists,
/// anomaly mode is enabled for all computation on all threads. It also
/// comes with a significant performance penalty.
///
/// Example:
/// @code
/// {
///   torch::autograd::DetectAnomalyGuard detect_anomaly;
///   auto x = torch::tensor({5.0}, torch::requires_grad());
///   auto y = x * x;
///   auto z = y * y;
///   y += 1; // in-place update of y; anomaly mode points at the failing op
///   z.backward();
/// }
/// @endcode
class TORCH_API DetectAnomalyGuard {
 public:
  DetectAnomalyGuard(bool check_nan = true);
  ~DetectAnomalyGuard();

 private:
  // check_nan setting in effect before this guard; restored on destruction
  bool prev_check_nan_;
};

struct TORCH_API AnomalyMetadata {
  virtual ~AnomalyMetadata();
  virtual void store_stack();
  virtual void print_stack(const std::string& current_node_name);
  virtual void assign_parent(const std::shared_ptr<Node>& parent_node);

 private:
  std::string traceback_;
  std::shared_ptr<Node> parent_;
};

} // namespace torch::autograd
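
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header in libtorch): a
// minimal translation unit exercising the two entry points above. The RAII
// guard is preferable to calling AnomalyMode::set_enabled() directly because
// its destructor restores the previous check_nan setting even if backward()
// throws. Assumes linking against libtorch; the function name debug_backward
// is hypothetical.
//
//   #include <torch/torch.h>
//   #include <torch/csrc/autograd/anomaly_mode.h>
//
//   void debug_backward() {
//     // Enable anomaly detection (with NaN checking) for this scope only.
//     torch::autograd::DetectAnomalyGuard guard(/*check_nan=*/true);
//     auto x = torch::tensor({0.0}, torch::requires_grad());
//     auto y = x * x.log(); // log(0) yields -inf; the backward pass of this
//                           // expression produces a NaN gradient
//     y.backward();         // with check_nan, the NaN gradient raises an
//                           // error identifying the responsible Node
//   }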