1 #include <c10/util/Exception.h>
2 #include <c10/util/Logging.h>
3 #include <c10/util/Type.h>
4
5 #include <sstream>
6 #include <string>
7 #include <utility>
8
9 namespace c10 {
10
Error(std::string msg,Backtrace backtrace,const void * caller)11 Error::Error(std::string msg, Backtrace backtrace, const void* caller)
12 : msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
13 refresh_what();
14 }
15
16 // PyTorch-style error message
17 // Error::Error(SourceLocation source_location, const std::string& msg)
18 // NB: This is defined in Logging.cpp for access to GetFetchStackTrace
19
20 // Caffe2-style error message
Error(const char * file,const uint32_t line,const char * condition,const std::string & msg,Backtrace backtrace,const void * caller)21 Error::Error(
22 const char* file,
23 const uint32_t line,
24 const char* condition,
25 const std::string& msg,
26 Backtrace backtrace,
27 const void* caller)
28 : Error(
29 str("[enforce fail at ",
30 detail::StripBasename(file),
31 ":",
32 line,
33 "] ",
34 condition,
35 ". ",
36 msg),
37 std::move(backtrace),
38 caller) {}
39
compute_what(bool include_backtrace) const40 std::string Error::compute_what(bool include_backtrace) const {
41 std::ostringstream oss;
42
43 oss << msg_;
44
45 if (context_.size() == 1) {
46 // Fold error and context in one line
47 oss << " (" << context_[0] << ")";
48 } else {
49 for (const auto& c : context_) {
50 oss << "\n " << c;
51 }
52 }
53
54 if (include_backtrace && backtrace_) {
55 oss << "\n" << backtrace_->get();
56 }
57
58 return oss.str();
59 }
60
backtrace() const61 const Backtrace& Error::backtrace() const {
62 return backtrace_;
63 }
64
what() const65 const char* Error::what() const noexcept {
66 return what_
67 .ensure([this] {
68 try {
69 return compute_what(/*include_backtrace*/ true);
70 } catch (...) {
71 // what() is noexcept, we need to return something here.
72 return std::string{"<Error computing Error::what()>"};
73 }
74 })
75 .c_str();
76 }
77
refresh_what()78 void Error::refresh_what() {
79 // Do not compute what_ eagerly, as it would trigger the computation of the
80 // backtrace. Instead, invalidate it, it will be computed on first access.
81 // refresh_what() is only called by non-const public methods which are not
82 // supposed to be called concurrently with any other method, so it is safe to
83 // invalidate here.
84 what_.reset();
85 what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
86 }
87
add_context(std::string new_msg)88 void Error::add_context(std::string new_msg) {
89 context_.push_back(std::move(new_msg));
90 // TODO: Calling add_context O(n) times has O(n^2) cost. We can fix
91 // this perf problem by populating the fields lazily... if this ever
92 // actually is a problem.
93 // NB: If you do fix this, make sure you do it in a thread safe way!
94 // what() is almost certainly expected to be thread safe even when
95 // accessed across multiple threads
96 refresh_what();
97 }
98
99 namespace detail {
100
torchCheckFail(const char * func,const char * file,uint32_t line,const std::string & msg)101 void torchCheckFail(
102 const char* func,
103 const char* file,
104 uint32_t line,
105 const std::string& msg) {
106 throw ::c10::Error({func, file, line}, msg);
107 }
108
torchCheckFail(const char * func,const char * file,uint32_t line,const char * msg)109 void torchCheckFail(
110 const char* func,
111 const char* file,
112 uint32_t line,
113 const char* msg) {
114 throw ::c10::Error({func, file, line}, msg);
115 }
116
torchInternalAssertFail(const char * func,const char * file,uint32_t line,const char * condMsg,const char * userMsg)117 void torchInternalAssertFail(
118 const char* func,
119 const char* file,
120 uint32_t line,
121 const char* condMsg,
122 const char* userMsg) {
123 torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
124 }
125
126 // This should never be called. It is provided in case of compilers
127 // that don't do any dead code stripping in debug builds.
torchInternalAssertFail(const char * func,const char * file,uint32_t line,const char * condMsg,const std::string & userMsg)128 void torchInternalAssertFail(
129 const char* func,
130 const char* file,
131 uint32_t line,
132 const char* condMsg,
133 const std::string& userMsg) {
134 torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
135 }
136
137 } // namespace detail
138
139 namespace WarningUtils {
140
141 namespace {
getBaseHandler()142 WarningHandler* getBaseHandler() {
143 static WarningHandler base_warning_handler_ = WarningHandler();
144 return &base_warning_handler_;
145 }
146
147 class ThreadWarningHandler {
148 public:
149 ThreadWarningHandler() = delete;
150
get_handler()151 static WarningHandler* get_handler() {
152 if (!warning_handler_) {
153 warning_handler_ = getBaseHandler();
154 }
155 return warning_handler_;
156 }
157
set_handler(WarningHandler * handler)158 static void set_handler(WarningHandler* handler) {
159 warning_handler_ = handler;
160 }
161
162 private:
163 static thread_local WarningHandler* warning_handler_;
164 };
165
166 thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
167
168 } // namespace
169
set_warning_handler(WarningHandler * handler)170 void set_warning_handler(WarningHandler* handler) noexcept(true) {
171 ThreadWarningHandler::set_handler(handler);
172 }
173
get_warning_handler()174 WarningHandler* get_warning_handler() noexcept(true) {
175 return ThreadWarningHandler::get_handler();
176 }
177
// Flag behind set_warnAlways()/get_warnAlways(); starts out false.
bool warn_always = false;

// Overwrites the warn-always flag with `setting`.
void set_warnAlways(bool setting) noexcept(true) {
  warn_always = setting;
}

// Returns the current value of the warn-always flag.
bool get_warnAlways() noexcept(true) {
  return warn_always;
}
187
WarnAlways(bool setting)188 WarnAlways::WarnAlways(bool setting /*=true*/)
189 : prev_setting(get_warnAlways()) {
190 set_warnAlways(setting);
191 }
192
~WarnAlways()193 WarnAlways::~WarnAlways() {
194 set_warnAlways(prev_setting);
195 }
196
197 } // namespace WarningUtils
198
warn(const Warning & warning)199 void warn(const Warning& warning) {
200 WarningUtils::ThreadWarningHandler::get_handler()->process(warning);
201 }
202
Warning(warning_variant_t type,const SourceLocation & source_location,std::string msg,const bool verbatim)203 Warning::Warning(
204 warning_variant_t type,
205 const SourceLocation& source_location,
206 std::string msg,
207 const bool verbatim)
208 : type_(type),
209 source_location_(source_location),
210 msg_(std::move(msg)),
211 verbatim_(verbatim) {}
212
Warning(warning_variant_t type,SourceLocation source_location,detail::CompileTimeEmptyString msg,const bool verbatim)213 Warning::Warning(
214 warning_variant_t type,
215 SourceLocation source_location,
216 detail::CompileTimeEmptyString msg,
217 const bool verbatim)
218 : Warning(type, source_location, "", verbatim) {}
219
Warning(warning_variant_t type,SourceLocation source_location,const char * msg,const bool verbatim)220 Warning::Warning(
221 warning_variant_t type,
222 SourceLocation source_location,
223 const char* msg,
224 const bool verbatim)
225 : type_(type),
226 source_location_(source_location),
227 msg_(std::string(msg)),
228 verbatim_(verbatim) {}
229
type() const230 Warning::warning_variant_t Warning::type() const {
231 return type_;
232 }
233
source_location() const234 const SourceLocation& Warning::source_location() const {
235 return source_location_;
236 }
237
msg() const238 const std::string& Warning::msg() const {
239 return msg_;
240 }
241
verbatim() const242 bool Warning::verbatim() const {
243 return verbatim_;
244 }
245
process(const Warning & warning)246 void WarningHandler::process(const Warning& warning) {
247 LOG_AT_FILE_LINE(
248 WARNING, warning.source_location().file, warning.source_location().line)
249 << "Warning: " << warning.msg() << " (function "
250 << warning.source_location().function << ")";
251 }
252
GetExceptionString(const std::exception & e)253 std::string GetExceptionString(const std::exception& e) {
254 #ifdef __GXX_RTTI
255 return demangle(typeid(e).name()) + ": " + e.what();
256 #else
257 return std::string("Exception (no RTTI available): ") + e.what();
258 #endif // __GXX_RTTI
259 }
260
261 } // namespace c10
262