1 #include <c10/util/Exception.h>
2 #include <pybind11/pybind11.h>
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/autograd/python_anomaly_mode.h>
5 #include <torch/csrc/autograd/python_cpp_function.h>
6 #include <torch/csrc/python_headers.h>
7 #include <torch/csrc/utils/object_ptr.h>
8 #include <torch/csrc/utils/pybind.h>
9 #include <torch/csrc/utils/python_strings.h>
10
11 namespace torch::autograd {
12
store_stack()13 void PyAnomalyMetadata::store_stack() {
14 pybind11::gil_scoped_acquire gil;
15 THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
16 if (!mod) {
17 throw python_error();
18 }
19
20 THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
21 if (!list) {
22 throw python_error();
23 }
24
25 if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
26 throw python_error();
27 }
28 }
29
// Print the stored forward traceback for this node, then walk the chain of
// "parent_" entries (recorded by assign_parent) and print each ancestor's
// traceback, so the user sees the full lineage of forward calls.
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
  pybind11::gil_scoped_acquire gil;
  if (!PyDict_Check(dict())) {
    throw std::runtime_error("Anomaly metadata is not a python dictionary.");
  }
  // Borrowed reference; may be null if no trace was stored — _print_stack
  // handles that case with a warning.
  PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
  _print_stack(trace_stack, current_node_name, false);
  // Borrowed reference into our own metadata dict.
  PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));

  // if there is no "parent_" in metadata, then it means this metadata's node
  // is the root and stop printing the traceback
  while (pyparent) {
    // Owned reference to the parent node's metadata dict.
    THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
    if (!parent_metadata) {
      throw python_error();
    }
    THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
    if (!parent_name_pyobj) {
      throw python_error();
    }
    // Pointer is into parent_name_pyobj's internal buffer; copy it out
    // before the owner can go away.
    const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
    if (!parent_name_char) {
      throw python_error();
    }
    const std::string parent_name(parent_name_char);
    // Borrowed reference; null (no exception) when the key is absent.
    PyObject* parent_stack =
        PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
    _print_stack(parent_stack, parent_name, true);
    // get the parent of this node, if this node is a root, pyparent is simply
    // null
    // NOTE(review): pyparent is borrowed from parent_metadata's dict, which
    // parent_metadata releases at the end of this iteration; this presumably
    // stays valid because the node object itself keeps its metadata dict
    // (and hence the stored parent) alive.
    pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
  }
}
63
assign_parent(const std::shared_ptr<Node> & parent_node)64 void PyAnomalyMetadata::assign_parent(
65 const std::shared_ptr<Node>& parent_node) {
66 // assign the python object of parent_node in metadata["parent_"]
67 // if parent_node is nullptr, then do nothing (it can mean that "parent_" key
68 // is not in metadata)
69
70 pybind11::gil_scoped_acquire gil;
71 if (!parent_node)
72 return;
73
74 THPObjectPtr parent_node_(functionToPyObject(parent_node));
75 if (!parent_node_) {
76 throw python_error();
77 }
78 if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
79 throw python_error();
80 }
81 }
82
_print_stack(PyObject * stack,const std::string & current_node_name,bool is_parent)83 void _print_stack(
84 PyObject* stack,
85 const std::string& current_node_name,
86 bool is_parent) {
87 if (!stack) {
88 TORCH_WARN(
89 "Error detected in ",
90 current_node_name,
91 ". ",
92 "No forward pass information available. Enable detect anomaly "
93 "during forward pass for more information.");
94 return;
95 }
96
97 THPObjectPtr empty_string(PyUnicode_FromString(""));
98 if (!empty_string) {
99 throw python_error();
100 }
101
102 // stack is a list of Python strings ending with newlines. Use join to convert
103 // to a single string.
104 THPObjectPtr msg(PyUnicode_Join(empty_string, stack));
105 if (!msg) {
106 throw python_error();
107 }
108
109 if (!is_parent) {
110 TORCH_WARN(
111 "Error detected in ",
112 current_node_name,
113 ". ",
114 "Traceback of forward call that caused the error:\n",
115 THPUtils_unpackString(msg.get()));
116 } else {
117 TORCH_WARN(
118 "\n\n",
119 "Previous calculation was induced by ",
120 current_node_name,
121 ". "
122 "Traceback of forward call that induced the previous calculation:\n",
123 THPUtils_unpackString(msg.get()));
124 }
125 }
126
127 } // namespace torch::autograd
128