xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/traceback.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/traceback.h"
17 
18 #include <stdexcept>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/hash/hash.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "pybind11/pytypes.h"
26 #include "tensorflow/compiler/xla/python/exceptions.h"
27 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 namespace xla {
31 
32 namespace py = pybind11;
33 
34 bool Traceback::enabled_ = true;
35 
Traceback()36 Traceback::Traceback() {
37   DCHECK(PyGILState_Check());
38   PyThreadState* thread_state = PyThreadState_GET();
39 
40 #if PY_VERSION_HEX < 0x030b0000
41   for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
42        py_frame = py_frame->f_back) {
43     Py_INCREF(py_frame->f_code);
44     frames_.emplace_back(py_frame->f_code, py_frame->f_lasti);
45   }
46 #else   // PY_VERSION_HEX < 0x030b0000
47   for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state);
48        py_frame != nullptr; py_frame = PyFrame_GetBack(py_frame)) {
49     frames_.emplace_back(PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame));
50     Py_XDECREF(py_frame);
51   }
52 #endif  // PY_VERSION_HEX < 0x030b0000
53 }
54 
~Traceback()55 Traceback::~Traceback() {
56   for (auto& frame : frames_) {
57     DCHECK(PyGILState_Check());
58     Py_DECREF(frame.first);
59   }
60 }
61 
Traceback(Traceback && other)62 Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) {
63   // absl::InlinedVector does not always clear itself if moved. Since we rely on
64   // its empty() method to destroy Traceback differently, we explicitly clear
65   // here.
66   other.frames_.clear();
67 }
68 
ToString() const69 std::string Traceback::Frame::ToString() const {
70   return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
71 }
72 
ToString() const73 std::string Traceback::ToString() const {
74   std::vector<std::string> frame_strs;
75   frame_strs.reserve(frames_.size());
76   for (const Frame& frame : Frames()) {
77     frame_strs.push_back(frame.ToString());
78   }
79   return absl::StrJoin(frame_strs, "\n");
80 }
81 
Frames() const82 std::vector<Traceback::Frame> Traceback::Frames() const {
83   // We require the GIL because we manipulate Python strings.
84   CHECK(PyGILState_Check());
85   std::vector<Traceback::Frame> frames;
86   frames.reserve(frames_.size());
87   for (const auto& frame : frames_) {
88     frames.push_back(Frame{
89         std::string(py::reinterpret_borrow<py::str>(frame.first->co_filename)),
90         std::string(py::reinterpret_borrow<py::str>(frame.first->co_name)),
91         frame.first->co_firstlineno,
92         PyCode_Addr2Line(frame.first, frame.second)});
93   }
94   return frames;
95 }
96 
Get()97 std::shared_ptr<Traceback> Traceback::Get() {
98   DCHECK(PyGILState_Check());
99   if (!enabled_) {
100     return nullptr;
101   }
102   return std::make_shared<Traceback>();
103 }
104 
SafeDestroy(Traceback traceback)105 void Traceback::SafeDestroy(Traceback traceback) {
106   // We want Traceback objects to be safe to destroy without holding the
107   // GIL, so we defer destruction of the strings.
108   GlobalPyRefManager()->AddGarbage(traceback.frames_);
109   traceback.frames_.clear();
110 }
111 
SetEnabled(bool enabled)112 void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; }
113 
AsPythonTraceback() const114 py::object Traceback::AsPythonTraceback() const {
115   py::object traceback = py::none();
116   py::dict globals;
117   py::handle traceback_type(reinterpret_cast<PyObject*>(&PyTraceBack_Type));
118   for (const std::pair<PyCodeObject*, int>& frame : frames_) {
119     PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), frame.first,
120                                           globals.ptr(), /*locals=*/nullptr);
121 
122     traceback = traceback_type(
123         /*tb_next=*/std::move(traceback),
124         /*tb_frame=*/
125         py::reinterpret_steal<py::object>(
126             reinterpret_cast<PyObject*>(py_frame)),
127         /*tb_lasti=*/frame.second,
128         /*tb_lineno=*/PyCode_Addr2Line(frame.first, frame.second));
129   }
130   return traceback;
131 }
132 
BuildTracebackSubmodule(py::module & m)133 void BuildTracebackSubmodule(py::module& m) {
134   py::class_<Traceback::Frame>(m, "Frame")
135       .def_readonly("file_name", &Traceback::Frame::file_name)
136       .def_readonly("function_name", &Traceback::Frame::function_name)
137       .def_readonly("function_start_line",
138                     &Traceback::Frame::function_start_line)
139       .def_readonly("line_num", &Traceback::Frame::line_num)
140       .def("__repr__", [](const Traceback::Frame& frame) {
141         return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name,
142                                frame.line_num);
143       });
144 
145   py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
146       m, "Traceback", "Represents a Python stack trace.");
147   traceback.def_property_static(
148       "enabled", [](py::object /* cls */) { return Traceback::enabled(); },
149       [](py::object /* cls */, bool enabled) {
150         return Traceback::SetEnabled(enabled);
151       });
152   traceback.def_static(
153       "get_traceback", []() { return Traceback::Get(); },
154       R"doc(
155     Returns a :class:`Traceback` for the current thread.
156 
157     If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
158     that describes the Python stack of the calling thread. Stack trace
159     collection has a small overhead, so it is disabled by default. If traceback
160     collection is disabled, returns ``None``.
161     )doc");
162   traceback.def_property_readonly("frames", &Traceback::Frames);
163   traceback.def("raw_frames", [](const Traceback& tb) -> py::tuple {
164     // We return a tuple of lists, rather than a list of tuples, because it
165     // is cheaper to allocate only three Python objects for everything rather
166     // than one per frame.
167     py::list out_code(tb.raw_frames().size());
168     py::list out_lasti(tb.raw_frames().size());
169     for (size_t i = 0; i < tb.raw_frames().size(); ++i) {
170       const auto& frame = tb.raw_frames()[i];
171       out_code[i] = py::reinterpret_borrow<py::object>(
172           reinterpret_cast<PyObject*>(frame.first));
173       out_lasti[i] = py::int_(frame.second);
174     }
175     return py::make_tuple(out_code, out_lasti);
176   });
177   traceback.def("__str__", &Traceback::ToString);
178   traceback.def("__eq__",
179                 [](const Traceback& a, const Traceback& b) { return a == b; });
180   traceback.def("__hash__",
181                 [](const Traceback& tb) { return absl::HashOf(tb); });
182   traceback.def("as_python_traceback", &Traceback::AsPythonTraceback);
183 
184   traceback.def_static(
185       "code_addr2line",
186       [](py::handle code, int lasti) {
187         if (!PyCode_Check(code.ptr())) {
188           throw xla::XlaRuntimeError("code argument must be a code object");
189         }
190         return PyCode_Addr2Line(reinterpret_cast<PyCodeObject*>(code.ptr()),
191                                 lasti);
192       },
193       "Python wrapper around the Python C API function PyCode_Addr2Line");
194 
195 #if PY_VERSION_HEX < 0x030b0000
196   // This function replaces the exception traceback associated with the current
197   // Python thread.
198   m.def(
199       "replace_thread_exc_traceback",
200       [](py::object tb) {
201         if (!tb.is_none() && !PyTraceBack_Check(tb.ptr())) {
202           throw xla::XlaRuntimeError(
203               "argument must be a traceback object or None");
204         }
205         PyThreadState* thread_state = PyThreadState_Get();
206         if (!thread_state->exc_info->exc_traceback) {
207           throw xla::XlaRuntimeError(
208               "Current thread does not have an active "
209               "exception traceback");
210         }
211         PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback;
212         PyObject* new_tb = tb.is_none() ? nullptr : tb.release().ptr();
213         thread_state->exc_info->exc_traceback = new_tb;
214         Py_XDECREF(old_exc_traceback);
215       },
216       py::arg("traceback"));
217 #endif  // PY_VERSION_HEX < 0x30b0000
218 }
219 }  // namespace xla
220