1 #include <torch/csrc/jit/frontend/error_report.h>
2
3 #include <torch/csrc/jit/frontend/tree.h>
4
5 namespace torch::jit {
6
// Per-thread stack of Call entries (function name + call-site range).
// CallStack's constructor pushes an entry and its destructor pops it;
// ErrorReport(const SourceRange&) snapshots the whole stack on construction.
// Avoid storing objects with destructor in thread_local for mobile build.
#ifndef C10_MOBILE
thread_local std::vector<Call> calls;
#endif // C10_MOBILE
11
// Copy constructor.
//
// std::stringstream is not copyable, so the buffered message text is
// re-streamed into a freshly default-constructed `ss`. Constructing the
// stream directly from `e.ss.str()` would leave its put position at the
// beginning of the buffer, so any further `<<` on the copy would overwrite
// the start of the copied message instead of appending to it.
ErrorReport::ErrorReport(const ErrorReport& e)
    : context(e.context),
      the_message(e.the_message),
      error_stack(e.error_stack.begin(), e.error_stack.end()) {
  // Writing through operator<< leaves the put position at the end, so later
  // appends continue after the copied text.
  ss << e.ss.str();
}
17
18 #ifndef C10_MOBILE
ErrorReport(const SourceRange & r)19 ErrorReport::ErrorReport(const SourceRange& r)
20 : context(r), error_stack(calls.begin(), calls.end()) {}
21
update_pending_range(const SourceRange & range)22 void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {
23 calls.back().caller_range = range;
24 }
25
CallStack(const std::string & name,const SourceRange & range)26 ErrorReport::CallStack::CallStack(
27 const std::string& name,
28 const SourceRange& range) {
29 calls.push_back({name, range});
30 }
31
~CallStack()32 ErrorReport::CallStack::~CallStack() {
33 calls.pop_back();
34 }
35 #else // defined C10_MOBILE
ErrorReport(const SourceRange & r)36 ErrorReport::ErrorReport(const SourceRange& r) : context(r) {}
37
update_pending_range(const SourceRange & range)38 void ErrorReport::CallStack::update_pending_range(const SourceRange& range) {}
39
CallStack(const std::string & name,const SourceRange & range)40 ErrorReport::CallStack::CallStack(
41 const std::string& name,
42 const SourceRange& range) {}
43
~CallStack()44 ErrorReport::CallStack::~CallStack() {}
45 #endif // C10_MOBILE
46
get_stacked_errors(const std::vector<Call> & error_stack)47 static std::string get_stacked_errors(const std::vector<Call>& error_stack) {
48 std::stringstream msg;
49 if (!error_stack.empty()) {
50 for (auto it = error_stack.rbegin(); it != error_stack.rend() - 1; ++it) {
51 auto callee = it + 1;
52
53 msg << "'" << it->fn_name
54 << "' is being compiled since it was called from '" << callee->fn_name
55 << "'\n";
56 callee->caller_range.highlight(msg);
57 }
58 }
59 return msg.str();
60 }
61
current_call_stack()62 std::string ErrorReport::current_call_stack() {
63 #ifndef C10_MOBILE
64 return get_stacked_errors(calls);
65 #else
66 AT_ERROR("Call stack not supported on mobile");
67 #endif // C10_MOBILE
68 }
69
what() const70 const char* ErrorReport::what() const noexcept {
71 std::stringstream msg;
72 msg << "\n" << ss.str();
73 msg << ":\n";
74 context.highlight(msg);
75
76 msg << get_stacked_errors(error_stack);
77
78 the_message = msg.str();
79 return the_message.c_str();
80 }
81
82 } // namespace torch::jit
83