xref: /aosp_15_r20/external/pytorch/c10/util/Exception.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Type.h>

#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>
8 
9 namespace c10 {
10 
Error(std::string msg,Backtrace backtrace,const void * caller)11 Error::Error(std::string msg, Backtrace backtrace, const void* caller)
12     : msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
13   refresh_what();
14 }
15 
16 // PyTorch-style error message
17 // Error::Error(SourceLocation source_location, const std::string& msg)
18 // NB: This is defined in Logging.cpp for access to GetFetchStackTrace
19 
20 // Caffe2-style error message
Error(const char * file,const uint32_t line,const char * condition,const std::string & msg,Backtrace backtrace,const void * caller)21 Error::Error(
22     const char* file,
23     const uint32_t line,
24     const char* condition,
25     const std::string& msg,
26     Backtrace backtrace,
27     const void* caller)
28     : Error(
29           str("[enforce fail at ",
30               detail::StripBasename(file),
31               ":",
32               line,
33               "] ",
34               condition,
35               ". ",
36               msg),
37           std::move(backtrace),
38           caller) {}
39 
compute_what(bool include_backtrace) const40 std::string Error::compute_what(bool include_backtrace) const {
41   std::ostringstream oss;
42 
43   oss << msg_;
44 
45   if (context_.size() == 1) {
46     // Fold error and context in one line
47     oss << " (" << context_[0] << ")";
48   } else {
49     for (const auto& c : context_) {
50       oss << "\n  " << c;
51     }
52   }
53 
54   if (include_backtrace && backtrace_) {
55     oss << "\n" << backtrace_->get();
56   }
57 
58   return oss.str();
59 }
60 
backtrace() const61 const Backtrace& Error::backtrace() const {
62   return backtrace_;
63 }
64 
what() const65 const char* Error::what() const noexcept {
66   return what_
67       .ensure([this] {
68         try {
69           return compute_what(/*include_backtrace*/ true);
70         } catch (...) {
71           // what() is noexcept, we need to return something here.
72           return std::string{"<Error computing Error::what()>"};
73         }
74       })
75       .c_str();
76 }
77 
refresh_what()78 void Error::refresh_what() {
79   // Do not compute what_ eagerly, as it would trigger the computation of the
80   // backtrace. Instead, invalidate it, it will be computed on first access.
81   // refresh_what() is only called by non-const public methods which are not
82   // supposed to be called concurrently with any other method, so it is safe to
83   // invalidate here.
84   what_.reset();
85   what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
86 }
87 
add_context(std::string new_msg)88 void Error::add_context(std::string new_msg) {
89   context_.push_back(std::move(new_msg));
90   // TODO: Calling add_context O(n) times has O(n^2) cost.  We can fix
91   // this perf problem by populating the fields lazily... if this ever
92   // actually is a problem.
93   // NB: If you do fix this, make sure you do it in a thread safe way!
94   // what() is almost certainly expected to be thread safe even when
95   // accessed across multiple threads
96   refresh_what();
97 }
98 
99 namespace detail {
100 
torchCheckFail(const char * func,const char * file,uint32_t line,const std::string & msg)101 void torchCheckFail(
102     const char* func,
103     const char* file,
104     uint32_t line,
105     const std::string& msg) {
106   throw ::c10::Error({func, file, line}, msg);
107 }
108 
torchCheckFail(const char * func,const char * file,uint32_t line,const char * msg)109 void torchCheckFail(
110     const char* func,
111     const char* file,
112     uint32_t line,
113     const char* msg) {
114   throw ::c10::Error({func, file, line}, msg);
115 }
116 
torchInternalAssertFail(const char * func,const char * file,uint32_t line,const char * condMsg,const char * userMsg)117 void torchInternalAssertFail(
118     const char* func,
119     const char* file,
120     uint32_t line,
121     const char* condMsg,
122     const char* userMsg) {
123   torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
124 }
125 
126 // This should never be called. It is provided in case of compilers
127 // that don't do any dead code stripping in debug builds.
torchInternalAssertFail(const char * func,const char * file,uint32_t line,const char * condMsg,const std::string & userMsg)128 void torchInternalAssertFail(
129     const char* func,
130     const char* file,
131     uint32_t line,
132     const char* condMsg,
133     const std::string& userMsg) {
134   torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
135 }
136 
137 } // namespace detail
138 
139 namespace WarningUtils {
140 
141 namespace {
getBaseHandler()142 WarningHandler* getBaseHandler() {
143   static WarningHandler base_warning_handler_ = WarningHandler();
144   return &base_warning_handler_;
145 }
146 
147 class ThreadWarningHandler {
148  public:
149   ThreadWarningHandler() = delete;
150 
get_handler()151   static WarningHandler* get_handler() {
152     if (!warning_handler_) {
153       warning_handler_ = getBaseHandler();
154     }
155     return warning_handler_;
156   }
157 
set_handler(WarningHandler * handler)158   static void set_handler(WarningHandler* handler) {
159     warning_handler_ = handler;
160   }
161 
162  private:
163   static thread_local WarningHandler* warning_handler_;
164 };
165 
166 thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
167 
168 } // namespace
169 
set_warning_handler(WarningHandler * handler)170 void set_warning_handler(WarningHandler* handler) noexcept(true) {
171   ThreadWarningHandler::set_handler(handler);
172 }
173 
get_warning_handler()174 WarningHandler* get_warning_handler() noexcept(true) {
175   return ThreadWarningHandler::get_handler();
176 }
177 
// Process-wide "warn always" flag. It is a plain bool with no
// synchronization, read and written through the accessors below.
bool warn_always{false};

// Sets the process-wide warn-always flag.
void set_warnAlways(bool setting) noexcept {
  warn_always = setting;
}

// Returns the current value of the process-wide warn-always flag.
bool get_warnAlways() noexcept {
  const bool current = warn_always;
  return current;
}
187 
WarnAlways(bool setting)188 WarnAlways::WarnAlways(bool setting /*=true*/)
189     : prev_setting(get_warnAlways()) {
190   set_warnAlways(setting);
191 }
192 
~WarnAlways()193 WarnAlways::~WarnAlways() {
194   set_warnAlways(prev_setting);
195 }
196 
197 } // namespace WarningUtils
198 
warn(const Warning & warning)199 void warn(const Warning& warning) {
200   WarningUtils::ThreadWarningHandler::get_handler()->process(warning);
201 }
202 
Warning(warning_variant_t type,const SourceLocation & source_location,std::string msg,const bool verbatim)203 Warning::Warning(
204     warning_variant_t type,
205     const SourceLocation& source_location,
206     std::string msg,
207     const bool verbatim)
208     : type_(type),
209       source_location_(source_location),
210       msg_(std::move(msg)),
211       verbatim_(verbatim) {}
212 
Warning(warning_variant_t type,SourceLocation source_location,detail::CompileTimeEmptyString msg,const bool verbatim)213 Warning::Warning(
214     warning_variant_t type,
215     SourceLocation source_location,
216     detail::CompileTimeEmptyString msg,
217     const bool verbatim)
218     : Warning(type, source_location, "", verbatim) {}
219 
Warning(warning_variant_t type,SourceLocation source_location,const char * msg,const bool verbatim)220 Warning::Warning(
221     warning_variant_t type,
222     SourceLocation source_location,
223     const char* msg,
224     const bool verbatim)
225     : type_(type),
226       source_location_(source_location),
227       msg_(std::string(msg)),
228       verbatim_(verbatim) {}
229 
type() const230 Warning::warning_variant_t Warning::type() const {
231   return type_;
232 }
233 
source_location() const234 const SourceLocation& Warning::source_location() const {
235   return source_location_;
236 }
237 
msg() const238 const std::string& Warning::msg() const {
239   return msg_;
240 }
241 
verbatim() const242 bool Warning::verbatim() const {
243   return verbatim_;
244 }
245 
process(const Warning & warning)246 void WarningHandler::process(const Warning& warning) {
247   LOG_AT_FILE_LINE(
248       WARNING, warning.source_location().file, warning.source_location().line)
249       << "Warning: " << warning.msg() << " (function "
250       << warning.source_location().function << ")";
251 }
252 
GetExceptionString(const std::exception & e)253 std::string GetExceptionString(const std::exception& e) {
254 #ifdef __GXX_RTTI
255   return demangle(typeid(e).name()) + ": " + e.what();
256 #else
257   return std::string("Exception (no RTTI available): ") + e.what();
258 #endif // __GXX_RTTI
259 }
260 
261 } // namespace c10
262