1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/traceback.h"
17
18 #include <stdexcept>
19 #include <string>
20 #include <utility>
21
22 #include "absl/hash/hash.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "pybind11/pytypes.h"
26 #include "tensorflow/compiler/xla/python/exceptions.h"
27 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
28 #include "tensorflow/core/platform/logging.h"
29
30 namespace xla {
31
32 namespace py = pybind11;
33
34 bool Traceback::enabled_ = true;
35
Traceback()36 Traceback::Traceback() {
37 DCHECK(PyGILState_Check());
38 PyThreadState* thread_state = PyThreadState_GET();
39
40 #if PY_VERSION_HEX < 0x030b0000
41 for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
42 py_frame = py_frame->f_back) {
43 Py_INCREF(py_frame->f_code);
44 frames_.emplace_back(py_frame->f_code, py_frame->f_lasti);
45 }
46 #else // PY_VERSION_HEX < 0x030b0000
47 for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state);
48 py_frame != nullptr; py_frame = PyFrame_GetBack(py_frame)) {
49 frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame));
50 Py_XDECREF(py_frame);
51 }
52 #endif // PY_VERSION_HEX < 0x030b0000
53 }
54
~Traceback()55 Traceback::~Traceback() {
56 for (auto& frame : frames_) {
57 DCHECK(PyGILState_Check());
58 Py_DECREF(frame.first);
59 }
60 }
61
Traceback(Traceback && other)62 Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) {
63 // absl::InlinedVector does not always clear itself if moved. Since we rely on
64 // its empty() method to destroy Traceback differently, we explicitly clear
65 // here.
66 other.frames_.clear();
67 }
68
ToString() const69 std::string Traceback::Frame::ToString() const {
70 return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
71 }
72
ToString() const73 std::string Traceback::ToString() const {
74 std::vector<std::string> frame_strs;
75 frame_strs.reserve(frames_.size());
76 for (const Frame& frame : Frames()) {
77 frame_strs.push_back(frame.ToString());
78 }
79 return absl::StrJoin(frame_strs, "\n");
80 }
81
Frames() const82 std::vector<Traceback::Frame> Traceback::Frames() const {
83 // We require the GIL because we manipulate Python strings.
84 CHECK(PyGILState_Check());
85 std::vector<Traceback::Frame> frames;
86 frames.reserve(frames_.size());
87 for (const auto& frame : frames_) {
88 frames.push_back(Frame{
89 std::string(py::reinterpret_borrow<py::str>(frame.first->co_filename)),
90 std::string(py::reinterpret_borrow<py::str>(frame.first->co_name)),
91 frame.first->co_firstlineno,
92 PyCode_Addr2Line(frame.first, frame.second)});
93 }
94 return frames;
95 }
96
Get()97 std::shared_ptr<Traceback> Traceback::Get() {
98 DCHECK(PyGILState_Check());
99 if (!enabled_) {
100 return nullptr;
101 }
102 return std::make_shared<Traceback>();
103 }
104
SafeDestroy(Traceback traceback)105 void Traceback::SafeDestroy(Traceback traceback) {
106 // We want Traceback objects to be safe to destroy without holding the
107 // GIL, so we defer destruction of the strings.
108 GlobalPyRefManager()->AddGarbage(traceback.frames_);
109 traceback.frames_.clear();
110 }
111
SetEnabled(bool enabled)112 void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; }
113
AsPythonTraceback() const114 py::object Traceback::AsPythonTraceback() const {
115 py::object traceback = py::none();
116 py::dict globals;
117 py::handle traceback_type(reinterpret_cast<PyObject*>(&PyTraceBack_Type));
118 for (const std::pair<PyCodeObject*, int>& frame : frames_) {
119 PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), frame.first,
120 globals.ptr(), /*locals=*/nullptr);
121
122 traceback = traceback_type(
123 /*tb_next=*/std::move(traceback),
124 /*tb_frame=*/
125 py::reinterpret_steal<py::object>(
126 reinterpret_cast<PyObject*>(py_frame)),
127 /*tb_lasti=*/frame.second,
128 /*tb_lineno=*/PyCode_Addr2Line(frame.first, frame.second));
129 }
130 return traceback;
131 }
132
BuildTracebackSubmodule(py::module & m)133 void BuildTracebackSubmodule(py::module& m) {
134 py::class_<Traceback::Frame>(m, "Frame")
135 .def_readonly("file_name", &Traceback::Frame::file_name)
136 .def_readonly("function_name", &Traceback::Frame::function_name)
137 .def_readonly("function_start_line",
138 &Traceback::Frame::function_start_line)
139 .def_readonly("line_num", &Traceback::Frame::line_num)
140 .def("__repr__", [](const Traceback::Frame& frame) {
141 return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name,
142 frame.line_num);
143 });
144
145 py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
146 m, "Traceback", "Represents a Python stack trace.");
147 traceback.def_property_static(
148 "enabled", [](py::object /* cls */) { return Traceback::enabled(); },
149 [](py::object /* cls */, bool enabled) {
150 return Traceback::SetEnabled(enabled);
151 });
152 traceback.def_static(
153 "get_traceback", []() { return Traceback::Get(); },
154 R"doc(
155 Returns a :class:`Traceback` for the current thread.
156
157 If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
158 that describes the Python stack of the calling thread. Stack trace
159 collection has a small overhead, so it is disabled by default. If traceback
160 collection is disabled, returns ``None``.
161 )doc");
162 traceback.def_property_readonly("frames", &Traceback::Frames);
163 traceback.def("raw_frames", [](const Traceback& tb) -> py::tuple {
164 // We return a tuple of lists, rather than a list of tuples, because it
165 // is cheaper to allocate only three Python objects for everything rather
166 // than one per frame.
167 py::list out_code(tb.raw_frames().size());
168 py::list out_lasti(tb.raw_frames().size());
169 for (size_t i = 0; i < tb.raw_frames().size(); ++i) {
170 const auto& frame = tb.raw_frames()[i];
171 out_code[i] = py::reinterpret_borrow<py::object>(
172 reinterpret_cast<PyObject*>(frame.first));
173 out_lasti[i] = py::int_(frame.second);
174 }
175 return py::make_tuple(out_code, out_lasti);
176 });
177 traceback.def("__str__", &Traceback::ToString);
178 traceback.def("__eq__",
179 [](const Traceback& a, const Traceback& b) { return a == b; });
180 traceback.def("__hash__",
181 [](const Traceback& tb) { return absl::HashOf(tb); });
182 traceback.def("as_python_traceback", &Traceback::AsPythonTraceback);
183
184 traceback.def_static(
185 "code_addr2line",
186 [](py::handle code, int lasti) {
187 if (!PyCode_Check(code.ptr())) {
188 throw xla::XlaRuntimeError("code argument must be a code object");
189 }
190 return PyCode_Addr2Line(reinterpret_cast<PyCodeObject*>(code.ptr()),
191 lasti);
192 },
193 "Python wrapper around the Python C API function PyCode_Addr2Line");
194
195 #if PY_VERSION_HEX < 0x030b0000
196 // This function replaces the exception traceback associated with the current
197 // Python thread.
198 m.def(
199 "replace_thread_exc_traceback",
200 [](py::object tb) {
201 if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) {
202 throw xla::XlaRuntimeError(
203 "argument must be a traceback object or None");
204 }
205 PyThreadState* thread_state = PyThreadState_Get();
206 if (!thread_state->exc_info->exc_traceback) {
207 throw xla::XlaRuntimeError(
208 "Current thread does not have an active "
209 "exception traceback");
210 }
211 PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback;
212 PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr();
213 thread_state->exc_info->exc_traceback = new_tb;
214 Py_XDECREF(old_exc_traceback);
215 },
216 py::arg("traceback"));
217 #endif // PY_VERSION_HEX < 0x30b0000
218 }
219 } // namespace xla
220