xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/diagnostics.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_DIAGNOSTICS_H_
17 #define XLA_RUNTIME_DIAGNOSTICS_H_
18 
#include <cassert>
#include <functional>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"
27 
28 namespace xla {
29 namespace runtime {
30 
31 // Forward declare.
32 class DiagnosticEngine;
33 
34 // XLA runtime diagnostics borrows a lot of ideas from the MLIR compile time
35 // diagnostics (which is largely based on the Swift compiler diagnostics),
36 // however in contrast to MLIR compilation pipelines we need to emit diagnostics
37 // for the run time events (vs compile time) and correlate them back to the
38 // location in the input module.
39 //
40 // See MLIR Diagnostics documentation: https://mlir.llvm.org/docs/Diagnostics.
41 //
42 // TODO(ezhulenev): Add location tracking, so that we can correlate emitted
43 // diagnostics to the location in the input module, and from there rely on the
44 // MLIR location to correlate events back to the user program (e.g. original
45 // JAX program written in Python).
46 //
47 // TODO(ezhulenev): In contrast to MLIR we don't have notes. Add them if needed.
48 
// Severity attached to every runtime diagnostic.
//
// The numeric values are spelled out so that the enumerator numbering is
// explicit; they match the declaration order (kWarning, kError, kRemark).
enum class DiagnosticSeverity {
  kWarning = 0,
  kError = 1,
  kRemark = 2,
};
50 
51 //===----------------------------------------------------------------------===//
52 // Diagnostic
53 //===----------------------------------------------------------------------===//
54 
55 class Diagnostic {
56  public:
Diagnostic(DiagnosticSeverity severity)57   explicit Diagnostic(DiagnosticSeverity severity) : severity_(severity) {}
58 
59   Diagnostic(Diagnostic &&) = default;
60   Diagnostic &operator=(Diagnostic &&) = default;
61 
62   // TODO(ezhulenev): Instead of relying on `<<` implementation pass diagnostic
63   // arguments explicitly, similar to MLIR?
64 
65   template <typename Arg>
66   Diagnostic &operator<<(Arg &&arg) {
67     llvm::raw_string_ostream(message_) << std::forward<Arg>(arg);
68     return *this;
69   }
70 
71   template <typename Arg>
append(Arg && arg)72   Diagnostic &append(Arg &&arg) {
73     *this << std::forward<Arg>(arg);
74     return *this;
75   }
76 
77   template <typename T>
78   Diagnostic &appendRange(const T &c, const char *delim = ", ") {
79     llvm::interleave(
80         c, [this](const auto &a) { *this << a; }, [&]() { *this << delim; });
81     return *this;
82   }
83 
severity()84   DiagnosticSeverity severity() const { return severity_; }
85 
str()86   std::string str() const { return message_; }
87 
88  private:
89   Diagnostic(const Diagnostic &rhs) = delete;
90   Diagnostic &operator=(const Diagnostic &rhs) = delete;
91 
92   DiagnosticSeverity severity_;
93   std::string message_;
94 };
95 
96 //===----------------------------------------------------------------------===//
97 // InFlightDiagnostic
98 //===----------------------------------------------------------------------===//
99 
100 // In flight diagnostic gives an opportunity to build a diagnostic before
101 // reporting it to the engine, similar to the builder pattern.
102 class InFlightDiagnostic {
103  public:
InFlightDiagnostic(InFlightDiagnostic && other)104   InFlightDiagnostic(InFlightDiagnostic &&other)
105       : engine_(other.engine_), diagnostic_(std::move(other.diagnostic_)) {
106     other.diagnostic_.reset();
107     other.Abandon();
108   }
109 
~InFlightDiagnostic()110   ~InFlightDiagnostic() {
111     if (IsInFlight()) Report();
112   }
113 
114   template <typename Arg>
115   InFlightDiagnostic &operator<<(Arg &&arg) & {
116     return append(std::forward<Arg>(arg));
117   }
118   template <typename Arg>
119   InFlightDiagnostic &&operator<<(Arg &&arg) && {
120     return std::move(append(std::forward<Arg>(arg)));
121   }
122 
123   template <typename Arg>
append(Arg && arg)124   InFlightDiagnostic &append(Arg &&arg) & {
125     assert(IsActive() && "diagnostic not active");
126     if (IsInFlight()) diagnostic_->append(std::forward<Arg>(arg));
127     return *this;
128   }
129 
130   template <typename Arg>
append(Arg && arg)131   InFlightDiagnostic &&append(Arg &&arg) && {
132     return std::move(append(std::forward<Arg>(arg)));
133   }
134 
135   void Report();
136   void Abandon();
137 
138   // Allow a diagnostic to be converted to 'failure'.
139   //
140   // Example:
141   //
142   //   LogicalResult call(DiagnosticEngine diag, ...) {
143   //     if (<check failed>) return diag.EmitError() << "Oops";
144   //     ...
145   //   }
146   //
LogicalResult()147   operator LogicalResult() const { return failure(); }  // NOLINT
148 
149  private:
150   friend class DiagnosticEngine;
151 
InFlightDiagnostic(const DiagnosticEngine * engine,Diagnostic diagnostic)152   InFlightDiagnostic(const DiagnosticEngine *engine, Diagnostic diagnostic)
153       : engine_(engine), diagnostic_(std::move(diagnostic)) {}
154 
155   InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
156   InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
157 
IsActive()158   bool IsActive() const { return diagnostic_.has_value(); }
IsInFlight()159   bool IsInFlight() const { return engine_ != nullptr; }
160 
161   // Diagnostic engine that will report this diagnostic once its ready.
162   const DiagnosticEngine *engine_ = nullptr;
163   llvm::Optional<Diagnostic> diagnostic_;
164 };
165 
166 //===----------------------------------------------------------------------===//
167 // DiagnosticEngine
168 //===----------------------------------------------------------------------===//
169 
170 // Diagnostic engine is responsible for passing diagnostics to the user.
171 //
172 // XLA runtime users must set up diagnostic engine to report errors back to the
173 // caller, e.g. the handler can collect all of the emitted diagnostics into the
174 // string message, and pass it to the caller as the async error.
175 //
176 // Unhandled error diagnostics will be dumped to the llvm::errs() stream.
177 class DiagnosticEngine {
178  public:
179   // Diagnostic handler must return success if it consumed the diagnostic, and
180   // failure if the engine should pass it to the next registered handler.
181   using HandlerTy = std::function<LogicalResult(Diagnostic &)>;
182 
183   // Returns the default instance of the diagnostic engine.
184   static const DiagnosticEngine *DefaultDiagnosticEngine();
185 
Emit(DiagnosticSeverity severity)186   InFlightDiagnostic Emit(DiagnosticSeverity severity) const {
187     return InFlightDiagnostic(this, Diagnostic(severity));
188   }
189 
EmitError()190   InFlightDiagnostic EmitError() const {
191     return Emit(DiagnosticSeverity::kError);
192   }
193 
AddHandler(HandlerTy handler)194   void AddHandler(HandlerTy handler) {
195     handlers_.push_back(std::move(handler));
196   }
197 
Emit(Diagnostic diagnostic)198   void Emit(Diagnostic diagnostic) const {
199     for (auto &handler : llvm::reverse(handlers_)) {
200       if (succeeded(handler(diagnostic))) return;
201     }
202 
203     // Dump unhandled errors to llvm::errs() stream.
204     if (diagnostic.severity() == DiagnosticSeverity::kError)
205       llvm::errs() << "Error: " << diagnostic.str() << "\n";
206   }
207 
208  private:
209   llvm::SmallVector<HandlerTy> handlers_;
210 };
211 
212 }  // namespace runtime
213 }  // namespace xla
214 
215 #endif  // XLA_RUNTIME_DIAGNOSTICS_H_
216