/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_RUNTIME_DIAGNOSTICS_H_
#define XLA_RUNTIME_DIAGNOSTICS_H_

#include <cassert>
#include <functional>
#include <string>
#include <utility>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"

namespace xla {
namespace runtime {

// Forward declare.
class DiagnosticEngine;

// XLA runtime diagnostics borrows a lot of ideas from the MLIR compile time
// diagnostics (which is largely based on the Swift compiler diagnostics),
// however in contrast to MLIR compilation pipelines we need to emit diagnostics
// for the run time events (vs compile time) and correlate them back to the
// location in the input module.
//
// See MLIR Diagnostics documentation: https://mlir.llvm.org/docs/Diagnostics.
//
// TODO(ezhulenev): Add location tracking, so that we can correlate emitted
// diagnostics to the location in the input module, and from there rely on the
// MLIR location to correlate events back to the user program (e.g. original
// JAX program written in Python).
//
// TODO(ezhulenev): In contrast to MLIR we don't have notes. Add them if needed.

48 49 enum class DiagnosticSeverity { kWarning, kError, kRemark }; 50 51 //===----------------------------------------------------------------------===// 52 // Diagnostic 53 //===----------------------------------------------------------------------===// 54 55 class Diagnostic { 56 public: Diagnostic(DiagnosticSeverity severity)57 explicit Diagnostic(DiagnosticSeverity severity) : severity_(severity) {} 58 59 Diagnostic(Diagnostic &&) = default; 60 Diagnostic &operator=(Diagnostic &&) = default; 61 62 // TODO(ezhulenev): Instead of relying on `<<` implementation pass diagnostic 63 // arguments explicitly, similar to MLIR? 64 65 template <typename Arg> 66 Diagnostic &operator<<(Arg &&arg) { 67 llvm::raw_string_ostream(message_) << std::forward<Arg>(arg); 68 return *this; 69 } 70 71 template <typename Arg> append(Arg && arg)72 Diagnostic &append(Arg &&arg) { 73 *this << std::forward<Arg>(arg); 74 return *this; 75 } 76 77 template <typename T> 78 Diagnostic &appendRange(const T &c, const char *delim = ", ") { 79 llvm::interleave( 80 c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; }); 81 return *this; 82 } 83 severity()84 DiagnosticSeverity severity() const { return severity_; } 85 str()86 std::string str() const { return message_; } 87 88 private: 89 Diagnostic(const Diagnostic &rhs) = delete; 90 Diagnostic &operator=(const Diagnostic &rhs) = delete; 91 92 DiagnosticSeverity severity_; 93 std::string message_; 94 }; 95 96 //===----------------------------------------------------------------------===// 97 // InFlightDiagnostic 98 //===----------------------------------------------------------------------===// 99 100 // In flight diagnostic gives an opportunity to build a diagnostic before 101 // reporting it to the engine, similar to the builder pattern. 
102 class InFlightDiagnostic { 103 public: InFlightDiagnostic(InFlightDiagnostic && other)104 InFlightDiagnostic(InFlightDiagnostic &&other) 105 : engine_(other.engine_), diagnostic_(std::move(other.diagnostic_)) { 106 other.diagnostic_.reset(); 107 other.Abandon(); 108 } 109 ~InFlightDiagnostic()110 ~InFlightDiagnostic() { 111 if (IsInFlight()) Report(); 112 } 113 114 template <typename Arg> 115 InFlightDiagnostic &operator<<(Arg &&arg) & { 116 return append(std::forward<Arg>(arg)); 117 } 118 template <typename Arg> 119 InFlightDiagnostic &&operator<<(Arg &&arg) && { 120 return std::move(append(std::forward<Arg>(arg))); 121 } 122 123 template <typename Arg> append(Arg && arg)124 InFlightDiagnostic &append(Arg &&arg) & { 125 assert(IsActive() && "diagnostic not active"); 126 if (IsInFlight()) diagnostic_->append(std::forward<Arg>(arg)); 127 return *this; 128 } 129 130 template <typename Arg> append(Arg && arg)131 InFlightDiagnostic &&append(Arg &&arg) && { 132 return std::move(append(std::forward<Arg>(arg))); 133 } 134 135 void Report(); 136 void Abandon(); 137 138 // Allow a diagnostic to be converted to 'failure'. 139 // 140 // Example: 141 // 142 // LogicalResult call(DiagnosticEngine diag, ...) { 143 // if (<check failed>) return diag.EmitError() << "Oops"; 144 // ... 
145 // } 146 // LogicalResult()147 operator LogicalResult() const { return failure(); } // NOLINT 148 149 private: 150 friend class DiagnosticEngine; 151 InFlightDiagnostic(const DiagnosticEngine * engine,Diagnostic diagnostic)152 InFlightDiagnostic(const DiagnosticEngine *engine, Diagnostic diagnostic) 153 : engine_(engine), diagnostic_(std::move(diagnostic)) {} 154 155 InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete; 156 InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete; 157 IsActive()158 bool IsActive() const { return diagnostic_.has_value(); } IsInFlight()159 bool IsInFlight() const { return engine_ != nullptr; } 160 161 // Diagnostic engine that will report this diagnostic once its ready. 162 const DiagnosticEngine *engine_ = nullptr; 163 llvm::Optional<Diagnostic> diagnostic_; 164 }; 165 166 //===----------------------------------------------------------------------===// 167 // DiagnosticEngine 168 //===----------------------------------------------------------------------===// 169 170 // Diagnostic engine is responsible for passing diagnostics to the user. 171 // 172 // XLA runtime users must set up diagnostic engine to report errors back to the 173 // caller, e.g. the handler can collect all of the emitted diagnostics into the 174 // string message, and pass it to the caller as the async error. 175 // 176 // Unhandled error diagnostics will be dumped to the llvm::errs() stream. 177 class DiagnosticEngine { 178 public: 179 // Diagnostic handler must return success if it consumed the diagnostic, and 180 // failure if the engine should pass it to the next registered handler. 181 using HandlerTy = std::function<LogicalResult(Diagnostic &)>; 182 183 // Returns the default instance of the diagnostic engine. 
184 static const DiagnosticEngine *DefaultDiagnosticEngine(); 185 Emit(DiagnosticSeverity severity)186 InFlightDiagnostic Emit(DiagnosticSeverity severity) const { 187 return InFlightDiagnostic(this, Diagnostic(severity)); 188 } 189 EmitError()190 InFlightDiagnostic EmitError() const { 191 return Emit(DiagnosticSeverity::kError); 192 } 193 AddHandler(HandlerTy handler)194 void AddHandler(HandlerTy handler) { 195 handlers_.push_back(std::move(handler)); 196 } 197 Emit(Diagnostic diagnostic)198 void Emit(Diagnostic diagnostic) const { 199 for (auto &handler : llvm::reverse(handlers_)) { 200 if (succeeded(handler(diagnostic))) return; 201 } 202 203 // Dump unhandled errors to llvm::errs() stream. 204 if (diagnostic.severity() == DiagnosticSeverity::kError) 205 llvm::errs() << "Error: " << diagnostic.str() << "\n"; 206 } 207 208 private: 209 llvm::SmallVector<HandlerTy> handlers_; 210 }; 211 212 } // namespace runtime 213 } // namespace xla 214 215 #endif // XLA_RUNTIME_DIAGNOSTICS_H_ 216