xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/tf_stack.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 */
15 
16 // We extract stack traces in Python using the logic in tf_stack.cc, which
17 // stores a list of PyCodeObject*. Such stack trace extraction is really fast.
18 //
19 // We store the retrieved stack trace within the Node object directly. Then
// whenever the graph is instantiated/copied, we copy the stack trace with it.
21 // Since the graph instantiation goes through the protobuf roundtrip, we store
22 // the original stack traces mapping attached in FunctionLibraryDefinition.
23 
24 // clang-format off
25 // These headers must be at the top, before including Python.h header
26 // Otherwise, we get C2039 on MSVC due to 'copysign'
27 #include "pybind11/complex.h"
28 #include "pybind11/pybind11.h"
29 #include "pybind11/stl.h"
30 #include "pybind11/stl_bind.h"
31 // clang-format on
32 
33 #include <frameobject.h>
34 
35 #include <algorithm>
36 #include <vector>
37 
38 #include "Python.h"
39 #include "absl/algorithm/container.h"
40 #include "absl/container/flat_hash_set.h"
41 #include "absl/hash/hash.h"
42 #include "absl/strings/str_format.h"
43 #include "absl/strings/str_join.h"
44 #include "absl/types/span.h"
45 #include "tensorflow/c/c_api_internal.h"
46 #include "tensorflow/core/graph/graph.h"
47 #include "tensorflow/core/platform/mutex.h"
48 #include "tensorflow/core/platform/path.h"
49 #include "tensorflow/python/util/stack_trace.h"
50 
51 struct StackFrame;  // Forward declaration.
52 struct StackTrace;
53 
54 PYBIND11_MAKE_OPAQUE(std::vector<StackFrame>);
55 PYBIND11_MAKE_OPAQUE(StackTrace);
56 
57 namespace tensorflow {
58 
59 namespace {
60 
61 namespace py = pybind11;
62 
63 using StringSet = absl::flat_hash_set<std::string>;
64 
65 // Python wrapper for a SourceMap.
66 class PyBindSourceMap {
67  public:
PyBindSourceMap()68   PyBindSourceMap() : source_map_(std::make_shared<SourceMap>()) {}
69 
70   // Shares ownership with whoever captures traces in the scope of this map.
71   std::shared_ptr<SourceMap> source_map_;
72 };
73 
74 // Python wrapper for a FileSet.
75 class PyBindFileSet {
76  public:
PyBindFileSet()77   PyBindFileSet() : file_set_(std::make_shared<StringSet>()) {}
78 
79   // Shares ownership with whoever captures traces in the scope of this set.
80   std::shared_ptr<StringSet> file_set_;
81 };
82 
83 // Returns contents of the line corresponding to the given frame.
84 //
85 // Precondition: must be holding Python GIL.
LineContents(const StackFrame & frame)86 py::str LineContents(const StackFrame& frame) {
87   DCheckPyGilStateForStackTrace();
88   // Pointers are to avoid static destruction of pybind::object, which
89   // occurs in uncontrollable states.
90   static const auto* inspect = new py::module(py::module::import("inspect"));
91   static const auto* getmodule = new py::function(inspect->attr("getmodule"));
92   static const auto* linecache =
93       new py::module(py::module::import("linecache"));
94   static const auto* checkcache =
95       new py::function(linecache->attr("checkcache"));
96   static const auto* getline = new py::function(linecache->attr("getline"));
97   (*checkcache)(py::str(frame.file_name));
98 
99   // Here we use the undocumented second argument of inspect.getmodule to look
100   // up a module from a filename. It has been unchanged since 2015.
101   const auto& module = (*getmodule)(py::none(), py::str(frame.file_name));
102   py::object dict = py::none();
103   if (!module.is_none()) {
104     // module dict is used by getline to resolve import hooks; see the
105     // stdlib's inspect module.
106     dict = module.attr("__dict__");
107   }
108   return py::cast<py::str>(
109       (*getline)(py::str(frame.file_name), py::int_(frame.line_number), dict)
110           .attr("strip")());
111 }
112 
113 // Ignores the frames containing this substring for common prefix calculation.
114 static const char* kFilenameToIgnorePrefix = "<embedded";
115 
116 // Converts the given stack frame to string, according to options defined in
117 // `opts`.
StackFrameToString(const StackFrame & frame,const AbstractStackTrace::TracePrintingOptions & opts,int shared_prefix_size=0)118 std::string StackFrameToString(
119     const StackFrame& frame,
120     const AbstractStackTrace::TracePrintingOptions& opts,
121     int shared_prefix_size = 0) {
122   std::string out = absl::StrFormat(
123       "File \"%s\", line %d, in %s",
124       absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)
125           ? frame.file_name
126           : frame.file_name.substr(shared_prefix_size),
127       frame.line_number, frame.function_name);
128 
129   if (opts.show_line_contents) {
130     PyGILState_STATE state = PyGILState_Ensure();
131     std::string line_contents = std::string(LineContents(frame));
132     PyGILState_Release(state);
133     if (!line_contents.empty()) {
134       absl::StrAppend(&out, "\n  ", line_contents);
135     }
136   }
137   return out;
138 }
139 
140 class StackTraceWrapper : public AbstractStackTrace {
141  public:
StackTraceWrapper(absl::Span<const StackFrame> stack_frames)142   explicit StackTraceWrapper(absl::Span<const StackFrame> stack_frames)
143       : stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
144                                                     stack_frames.end())) {}
145 
StackTraceWrapper(StackTraceWrapper && rhs)146   StackTraceWrapper(StackTraceWrapper&& rhs) {
147     captured_ = std::move(rhs.captured_);
148     source_map_ = std::move(rhs.source_map_);
149     filter_ = std::move(rhs.filter_);
150     stacklevel_ = rhs.stacklevel_;
151     tensorflow::mutex_lock lock(rhs.mu_);
152     stack_frames_cache_ = std::move(rhs.stack_frames_cache_);
153     last_stack_frame_cache_ = std::move(rhs.last_stack_frame_cache_);
154   }
155 
operator =(StackTraceWrapper && rhs)156   StackTraceWrapper& operator=(StackTraceWrapper&& rhs) {
157     if (&rhs == this) return *this;
158 
159     captured_ = std::move(rhs.captured_);
160     source_map_ = std::move(rhs.source_map_);
161     filter_ = std::move(rhs.filter_);
162     stacklevel_ = rhs.stacklevel_;
163 
164     tensorflow::mutex_lock self_lock(mu_);
165     tensorflow::mutex_lock rhs_lock(rhs.mu_);
166 
167     stack_frames_cache_ = std::move(rhs.stack_frames_cache_);
168     last_stack_frame_cache_ = std::move(rhs.last_stack_frame_cache_);
169     return *this;
170   }
171 
ExtractStack(const std::shared_ptr<SourceMap> & source_map,const std::shared_ptr<StringSet> & filter,int stacklevel)172   static StackTraceWrapper ExtractStack(
173       const std::shared_ptr<SourceMap>& source_map,
174       const std::shared_ptr<StringSet>& filter, int stacklevel) {
175     return StackTraceWrapper{StackTrace::Capture(-1), source_map, filter,
176                              stacklevel};
177   }
178 
ToFrames() const179   absl::Span<const StackFrame> ToFrames() const override {
180     tensorflow::mutex_lock lock(mu_);
181     if (stack_frames_cache_) {
182       return *stack_frames_cache_;
183     }
184 
185     // Grabbing the GIL solves two purposes: 1) makes the class thread-safe,
186     // and 2) ToStackFrames and LineContents actually need it.
187     PyGILState_STATE state = PyGILState_Ensure();
188 
189     stack_frames_cache_ = captured_.ToStackFrames(
190         *source_map_, [&](const char* f) { return StackTraceFiltering(f); });
191 
192     // Drop last stack frames.
193     int newsize = stack_frames_cache_->size() - stacklevel_;
194     if (newsize < 0) {
195       newsize = 0;
196     }
197     stack_frames_cache_->resize(newsize);
198 
199     PyGILState_Release(state);
200     return *stack_frames_cache_;
201   }
202 
get_stacklevel() const203   int get_stacklevel() const { return stacklevel_; }
204 
set_stacklevel(int stacklevel)205   void set_stacklevel(int stacklevel) { stacklevel_ = stacklevel; }
206 
GetUserFrames(int limit=-1) const207   std::vector<StackFrame> GetUserFrames(int limit = -1) const {
208     PyGILState_STATE state = PyGILState_Ensure();
209     std::vector<StackFrame> user_frames = captured_.ToStackFrames(
210         *source_map_,
211         [&](const char* file_name) {
212           return StackTraceFiltering(file_name) ||
213                  IsInternalFrameForFilename(file_name);
214         },
215         /*reverse_traversal=*/true,
216         /*limit=*/limit);
217     PyGILState_Release(state);
218     // ensure we use the original (outermost first) ordering.
219     absl::c_reverse(user_frames);
220     return user_frames;
221   }
222 
LastUserFrame() const223   StackFrame LastUserFrame() const override {
224     tensorflow::mutex_lock lock(mu_);
225     if (last_stack_frame_cache_) {
226       return *last_stack_frame_cache_;
227     }
228 
229     PyGILState_STATE state = PyGILState_Ensure();
230     std::vector<StackFrame> last_frame = GetUserFrames(1);
231 
232     if (last_frame.empty()) {
233       last_stack_frame_cache_ = StackFrame{"", -1, ""};
234     } else {
235       DCHECK_EQ(last_frame.size(), 1);
236       last_stack_frame_cache_ = last_frame[0];
237     }
238     PyGILState_Release(state);
239     return *last_stack_frame_cache_;
240   }
241 
242   // Erases a section of the stack trace.
Erase(int first,int last)243   void Erase(int first, int last) {
244     tensorflow::mutex_lock lock(mu_);
245     if (!stack_frames_cache_) {
246       ToFrames();
247     }
248     DCHECK_GE(first, 0);
249     DCHECK_LT(first, stack_frames_cache_->size());
250     DCHECK_GE(last, 0);
251     DCHECK_LE(last, stack_frames_cache_->size());
252     auto it = stack_frames_cache_->begin();
253     stack_frames_cache_->erase(it + first, it + last);
254   }
255 
ToString(const TracePrintingOptions & opts) const256   std::string ToString(const TracePrintingOptions& opts) const override {
257     std::vector<std::string> files_to_find_prefix;
258     for (const StackFrame& frame : ToFrames()) {
259       if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
260         files_to_find_prefix.push_back(frame.file_name);
261       }
262     }
263     int shared_prefix_size =
264         opts.filter_common_prefix
265             ? io::CommonPathPrefix(files_to_find_prefix).size()
266             : 0;
267 
268     tensorflow::mutex_lock lock(mu_);
269     if (!opts.drop_internal_frames) {
270       return ToStringHelper(*stack_frames_cache_, opts, shared_prefix_size);
271     }
272 
273     std::vector<StackFrame> filtered_frames;
274     for (const StackFrame& frame : *stack_frames_cache_) {
275       if (!IsInternalFrameForFilename(frame.file_name)) {
276         filtered_frames.push_back(frame);
277       }
278     }
279     return ToStringHelper(filtered_frames, opts, shared_prefix_size);
280   }
281 
~StackTraceWrapper()282   ~StackTraceWrapper() override {
283     PyGILState_STATE state = PyGILState_Ensure();
284     captured_.Clear();
285     source_map_.reset();
286     filter_.reset();
287     PyGILState_Release(state);
288   }
289 
290  private:
StackTraceWrapper(StackTrace && captured,const std::shared_ptr<SourceMap> & source_map,const std::shared_ptr<StringSet> & filter,int stacklevel)291   StackTraceWrapper(StackTrace&& captured,
292                     const std::shared_ptr<SourceMap>& source_map,
293                     const std::shared_ptr<StringSet>& filter, int stacklevel)
294       : captured_(std::move(captured)),
295         source_map_(source_map),
296         filter_(filter),
297         stacklevel_(stacklevel) {}
298 
ToStringHelper(absl::Span<const StackFrame> stack_frames,const TracePrintingOptions & opts,int shared_prefix_size)299   static std::string ToStringHelper(absl::Span<const StackFrame> stack_frames,
300                                     const TracePrintingOptions& opts,
301                                     int shared_prefix_size) {
302     return absl::StrJoin(
303         stack_frames, "\n", [&](std::string* out, const StackFrame& frame) {
304           absl::StrAppend(out,
305                           StackFrameToString(frame, opts, shared_prefix_size));
306         });
307   }
308 
StackTraceFiltering(const char * file_name) const309   bool StackTraceFiltering(const char* file_name) const {
310     return filter_->contains(file_name);
311   }
312 
313   // Note: Make sure to update move constructor while adding new member
314   // variables.
315   StackTrace captured_;
316   std::shared_ptr<SourceMap> source_map_;
317   std::shared_ptr<StringSet> filter_;
318   int stacklevel_;
319 
320   // Using optional to force destruction while we hold a GIL.
321   mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_
322       TF_GUARDED_BY(mu_);
323   mutable absl::optional<StackFrame> last_stack_frame_cache_ TF_GUARDED_BY(mu_);
324   mutable mutex mu_;
325 };
326 
327 }  // namespace
328 
PYBIND11_MODULE(_tf_stack,m)329 PYBIND11_MODULE(_tf_stack, m) {
330   py::class_<PyBindSourceMap>(m, "PyBindSourceMap")
331       .def(py::init())
332       .def("update_to",
333            [](const PyBindSourceMap& self, const py::tuple& source_map) {
334              self.source_map_->clear();
335              for (const auto& item : source_map) {
336                const auto& tuple_item = py::cast<py::tuple>(item);
337 
338                const auto& key = py::cast<py::tuple>(tuple_item[0]);
339                std::string&& k_filename = py::cast<std::string>(key[0]);
340                int k_lineno = py::cast<int>(key[1]);
341 
342                const auto& value = py::cast<py::tuple>(tuple_item[1]);
343                std::string&& v_filename = py::cast<std::string>(value[0]);
344                int v_lineno = py::cast<int>(value[1]);
345                const auto& function_name_val = value[2];
346                std::string&& v_function_name =
347                    function_name_val.is_none()
348                        ? ""
349                        : py::cast<std::string>(function_name_val);
350 
351                self.source_map_->emplace(
352                    SourceLoc{k_filename, k_lineno},
353                    StackFrame({v_filename, v_lineno, v_function_name}));
354              }
355            });
356 
357   py::class_<PyBindFileSet>(m, "PyBindFileSet")
358       .def(py::init())
359       .def("update_to", [](const PyBindFileSet& self, const py::set& file_set) {
360         self.file_set_->clear();
361         for (const auto& item : file_set) {
362           self.file_set_->insert(py::cast<std::string>(item));
363         }
364       });
365 
366   py::class_<StackFrame>(m, "StackFrame")
367       .def_property_readonly(
368           "filename",
369           [](const StackFrame& self) { return py::str(self.file_name); })
370       .def_property_readonly(
371           "lineno",
372           [](const StackFrame& self) { return py::int_(self.line_number); })
373       .def_property_readonly(
374           "name",
375           [](const StackFrame& self) { return py::str(self.function_name); })
376       .def_property_readonly(
377           "line", [](const StackFrame& self) { return LineContents(self); })
378 
379       // For compatibility with the traceback module.
380       .def("__eq__", &StackFrame::operator==)
381       .def("__ne__", &StackFrame::operator!=)
382       .def("__hash__",
383            [](const StackFrame& self) {
384              return absl::Hash<std::tuple<std::string, int, std::string>>()(
385                  std::make_tuple(self.file_name, self.line_number,
386                                  self.function_name));
387            })
388       .def("__getitem__",
389            [](const StackFrame& self, const py::object& index) -> py::object {
390              return py::make_tuple(
391                  py::str(self.file_name), py::int_(self.line_number),
392                  py::str(self.function_name), LineContents(self))[index];
393            })
394       .def("__iter__",
395            [](const StackFrame& self) {
396              return py::iter(py::make_tuple(
397                  py::str(self.file_name), py::int_(self.line_number),
398                  py::str(self.function_name), LineContents(self))
399 
400              );
401            })
402       .def("__repr__",
403            [](const StackFrame& self) { return StackFrameToString(self, {}); })
404       .def("__len__", [](const StackFrame&) { return 4; });
405 
406   py::class_<StackTraceWrapper>(m, "StackTraceWrapper")
407       // TODO(slebedev): upstream negative indexing support into pybind11.
408       .def(
409           "__getitem__",
410           [](const StackTraceWrapper& self, py::ssize_t index) {
411             absl::Span<const StackFrame> frames = self.ToFrames();
412             const size_t eff_index =
413                 index < 0 ? frames.size() + index : static_cast<size_t>(index);
414             if (eff_index >= frames.size()) {
415               throw py::index_error();
416             }
417             return frames[eff_index];
418           },
419           py::return_value_policy::reference_internal)
420       .def(
421           "__getitem__",
422           [](const StackTraceWrapper& self, py::slice slice) {
423             absl::Span<const StackFrame> frames = self.ToFrames();
424             py::ssize_t start, stop, step, slicelength;
425             if (!slice.compute(frames.size(), &start, &stop, &step,
426                                &slicelength)) {
427               throw py::error_already_set();
428             }
429             if (step == 1) {
430               return StackTraceWrapper{frames.subspan(start, slicelength)};
431             }
432             // TODO(cheshire): Cleanup, use Python slicing logic directly
433             // instead.
434             std::vector<StackFrame> out;
435             out.reserve(slicelength);
436             // Python slices allow negative indexing.
437             for (int i = start; i != stop; i += step) {
438               out.push_back(frames[i]);
439             }
440             return StackTraceWrapper{out};
441           },
442           py::return_value_policy::reference_internal)
443       .def("__delitem__",
444            [](StackTraceWrapper& self, py::ssize_t index) {
445              absl::Span<const StackFrame> frames = self.ToFrames();
446              const size_t eff_index =
447                  index < 0 ? frames.size() + index : static_cast<size_t>(index);
448              if (eff_index >= frames.size()) {
449                throw py::index_error();
450              }
451              self.Erase(eff_index, eff_index + 1);
452            })
453       .def("__delitem__",
454            [](StackTraceWrapper& self, py::slice slice) {
455              absl::Span<const StackFrame> frames = self.ToFrames();
456              py::ssize_t start, stop, step, slicelength;
457              if (!slice.compute(frames.size(), &start, &stop, &step,
458                                 &slicelength)) {
459                throw py::error_already_set();
460              }
461              if (step != 1) {
462                throw py::index_error();
463              }
464              if (stop > start) {
465                self.Erase(start, stop);
466              }
467            })
468       .def("__len__",
469            [](const StackTraceWrapper& self) { return self.ToFrames().size(); })
470       .def("__eq__",
471            [](const StackTraceWrapper& self, const StackTraceWrapper& other) {
472              return self.ToFrames() == other.ToFrames();
473            })
474       .def("__hash__",
475            [](const StackTraceWrapper& self) {
476              return py::hash(py::str(self.ToString({})));
477            })
478       // NOTE(feyu): consider remove this and use traceback.format_list(tb)
479       // to format the trace.
480       .def("__repr__",
481            [](const StackTraceWrapper& self) {
482              return py::str(self.ToString({}));
483            })
484       .def_property(
485           "_stacklevel", &StackTraceWrapper::get_stacklevel,
486           &StackTraceWrapper::set_stacklevel,
487           "Adjusts stacklevel; no effects after ToFrames() is called.")
488       .def(
489           "get_user_frames",
490           [](const StackTraceWrapper& self) {
491             return StackTraceWrapper{self.GetUserFrames()};
492           },
493           "Returns the non-framework frames as a new trace object.")
494       .def(
495           "last_user_frame",
496           [](const StackTraceWrapper& self) { return self.LastUserFrame(); },
497           "Returns the last non-framework frame.");
498 
499   m.def("extract_stack_for_op", [](const PyBindSourceMap& source_map,
500                                    const PyBindFileSet& file_set,
501                                    TF_Operation* op, int stacklevel) {
502     DCHECK(!op->node.GetStackTrace()) << "Should not reset the stack trace";
503     op->node.SetStackTrace(
504         std::make_shared<StackTraceWrapper>(StackTraceWrapper::ExtractStack(
505             source_map.source_map_, file_set.file_set_, stacklevel)));
506   });
507 
508   m.def(
509       "extract_stack",
510       [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set) {
511         return StackTraceWrapper::ExtractStack(source_map.source_map_,
512                                                file_set.file_set_, 1);
513       },
514       py::return_value_policy::move);
515 }
516 
517 }  // namespace tensorflow
518