// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/anomaly_mode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <memory>
5 #include <string>
6 
7 namespace torch::autograd {
8 
9 // forward declaration of Node from function.h
10 struct Node;
11 
12 struct TORCH_API AnomalyMode {
is_enabledAnomalyMode13   static bool is_enabled() {
14     return _enabled;
15   }
should_check_nanAnomalyMode16   static bool should_check_nan() {
17     return _check_nan;
18   }
19   static void set_enabled(bool enabled, bool check_nan = true) {
20     _enabled = enabled;
21     _check_nan = check_nan;
22   }
23 
24  private:
25   static bool _enabled;
26   static bool _check_nan;
27 };
28 
/// An RAII guard that enables Anomaly Detection Mode.
///
/// Anomaly detection mode is useful for debugging problems that happen
/// during the backward pass, such as tensors that were unexpectedly
/// modified or NaNs occurring in the backward.
///
/// Enabling anomaly mode is global: as soon as one such guard exists,
/// anomaly detection is enabled for all computation and all threads. It also
/// comes with a significant performance penalty.
///
/// Example:
/// @code
/// {
///   torch::autograd::DetectAnomalyGuard detect_anomaly;
///   auto x = torch::tensor({5.0}, torch::requires_grad());
///   auto y = x * x;
///   auto z = y * y;
///   y += 1;  // modifies y in place after it was used to compute z
///   z.backward();
/// }
/// @endcode
51 class TORCH_API DetectAnomalyGuard {
52  public:
53   DetectAnomalyGuard(bool check_nan = true);
54   ~DetectAnomalyGuard();
55 
56  private:
57   bool prev_check_nan_;
58 };
59 
60 struct TORCH_API AnomalyMetadata {
61   virtual ~AnomalyMetadata();
62   virtual void store_stack();
63   virtual void print_stack(const std::string& current_node_name);
64   virtual void assign_parent(const std::shared_ptr<Node>& parent_node);
65 
66  private:
67   std::string traceback_;
68   std::shared_ptr<Node> parent_;
69 };
70 
71 } // namespace torch::autograd
72