xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <pybind11/pybind11.h>
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/autograd/python_anomaly_mode.h>
5 #include <torch/csrc/autograd/python_cpp_function.h>
6 #include <torch/csrc/python_headers.h>
7 #include <torch/csrc/utils/object_ptr.h>
8 #include <torch/csrc/utils/pybind.h>
9 #include <torch/csrc/utils/python_strings.h>
10 
11 namespace torch::autograd {
12 
store_stack()13 void PyAnomalyMetadata::store_stack() {
14   pybind11::gil_scoped_acquire gil;
15   THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
16   if (!mod) {
17     throw python_error();
18   }
19 
20   THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
21   if (!list) {
22     throw python_error();
23   }
24 
25   if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
26     throw python_error();
27   }
28 }
29 
print_stack(const std::string & current_node_name)30 void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
31   pybind11::gil_scoped_acquire gil;
32   if (!PyDict_Check(dict())) {
33     throw std::runtime_error("Anomaly metadata is not a python dictionary.");
34   }
35   PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
36   _print_stack(trace_stack, current_node_name, false);
37   PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));
38 
39   // if there is no "parent_" in metadata, then it means this metadata's node
40   // is the root and stop printing the traceback
41   while (pyparent) {
42     THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
43     if (!parent_metadata) {
44       throw python_error();
45     }
46     THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
47     if (!parent_name_pyobj) {
48       throw python_error();
49     }
50     const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
51     if (!parent_name_char) {
52       throw python_error();
53     }
54     const std::string parent_name(parent_name_char);
55     PyObject* parent_stack =
56         PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
57     _print_stack(parent_stack, parent_name, true);
58     // get the parent of this node, if this node is a root, pyparent is simply
59     // null
60     pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
61   }
62 }
63 
assign_parent(const std::shared_ptr<Node> & parent_node)64 void PyAnomalyMetadata::assign_parent(
65     const std::shared_ptr<Node>& parent_node) {
66   // assign the python object of parent_node in metadata["parent_"]
67   // if parent_node is nullptr, then do nothing (it can mean that "parent_" key
68   // is not in metadata)
69 
70   pybind11::gil_scoped_acquire gil;
71   if (!parent_node)
72     return;
73 
74   THPObjectPtr parent_node_(functionToPyObject(parent_node));
75   if (!parent_node_) {
76     throw python_error();
77   }
78   if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
79     throw python_error();
80   }
81 }
82 
_print_stack(PyObject * stack,const std::string & current_node_name,bool is_parent)83 void _print_stack(
84     PyObject* stack,
85     const std::string& current_node_name,
86     bool is_parent) {
87   if (!stack) {
88     TORCH_WARN(
89         "Error detected in ",
90         current_node_name,
91         ". ",
92         "No forward pass information available. Enable detect anomaly "
93         "during forward pass for more information.");
94     return;
95   }
96 
97   THPObjectPtr empty_string(PyUnicode_FromString(""));
98   if (!empty_string) {
99     throw python_error();
100   }
101 
102   // stack is a list of Python strings ending with newlines. Use join to convert
103   // to a single string.
104   THPObjectPtr msg(PyUnicode_Join(empty_string, stack));
105   if (!msg) {
106     throw python_error();
107   }
108 
109   if (!is_parent) {
110     TORCH_WARN(
111         "Error detected in ",
112         current_node_name,
113         ". ",
114         "Traceback of forward call that caused the error:\n",
115         THPUtils_unpackString(msg.get()));
116   } else {
117     TORCH_WARN(
118         "\n\n",
119         "Previous calculation was induced by ",
120         current_node_name,
121         ". "
122         "Traceback of forward call that induced the previous calculation:\n",
123         THPUtils_unpackString(msg.get()));
124   }
125 }
126 
127 } // namespace torch::autograd
128