xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/error_report.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/error_report.h>
2 
3 #include <torch/csrc/jit/frontend/tree.h>
4 
5 namespace torch::jit {
6 
// Avoid storing objects that have destructors in thread_local storage on
// mobile builds.
8 #ifndef C10_MOBILE
9 thread_local std::vector<Call> calls;
10 #endif // C10_MOBILE
11 
// Copy constructor.
//
// `ss` is rebuilt from the text accumulated in `e.ss` (seeded from
// `e.ss.str()`). The stream is opened with `std::ios_base::ate` so that its
// write position starts at the END of the copied text: with the default open
// mode the put position starts at offset 0, and any later `operator<<` on the
// copy would overwrite the existing message instead of appending to it.
// `context`, `the_message` and `error_stack` are copied member-wise.
ErrorReport::ErrorReport(const ErrorReport& e)
    : ss(e.ss.str(),
         std::ios_base::in | std::ios_base::out | std::ios_base::ate),
      context(e.context),
      the_message(e.the_message),
      error_stack(e.error_stack.begin(), e.error_stack.end()) {}
17 
18 #ifndef C10_MOBILE
ErrorReport(const SourceRange & r)19 ErrorReport::ErrorReport(const SourceRange& r)
20     : context(r), error_stack(calls.begin(), calls.end()) {}
21 
update_pending_range(const SourceRange & range)22 void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {
23   calls.back().caller_range = range;
24 }
25 
CallStack(const std::string & name,const SourceRange & range)26 ErrorReport::CallStack::CallStack(
27     const std::string& name,
28     const SourceRange& range) {
29   calls.push_back({name, range});
30 }
31 
~CallStack()32 ErrorReport::CallStack::~CallStack() {
33   calls.pop_back();
34 }
35 #else // defined C10_MOBILE
ErrorReport(const SourceRange & r)36 ErrorReport::ErrorReport(const SourceRange& r) : context(r) {}
37 
update_pending_range(const SourceRange & range)38 void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {}
39 
CallStack(const std::string & name,const SourceRange & range)40 ErrorReport::CallStack::CallStack(
41     const std::string& name,
42     const SourceRange& range) {}
43 
~CallStack()44 ErrorReport::CallStack::~CallStack() {}
45 #endif // C10_MOBILE
46 
get_stacked_errors(const std::vector<Call> & error_stack)47 static std::string get_stacked_errors(const std::vector<Call>& error_stack) {
48   std::stringstream msg;
49   if (!error_stack.empty()) {
50     for (auto it = error_stack.rbegin(); it != error_stack.rend() - 1; ++it) {
51       auto callee = it + 1;
52 
53       msg << "'" << it->fn_name
54           << "' is being compiled since it was called from '" << callee->fn_name
55           << "'\n";
56       callee->caller_range.highlight(msg);
57     }
58   }
59   return msg.str();
60 }
61 
current_call_stack()62 std::string ErrorReport::current_call_stack() {
63 #ifndef C10_MOBILE
64   return get_stacked_errors(calls);
65 #else
66   AT_ERROR("Call stack not supported on mobile");
67 #endif // C10_MOBILE
68 }
69 
what() const70 const char* ErrorReport::what() const noexcept {
71   std::stringstream msg;
72   msg << "\n" << ss.str();
73   msg << ":\n";
74   context.highlight(msg);
75 
76   msg << get_stacked_errors(error_stack);
77 
78   the_message = msg.str();
79   return the_message.c_str();
80 }
81 
82 } // namespace torch::jit
83