xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
17 
18 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
19 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
20 #include "tensorflow/core/platform/errors.h"
21 #include "tensorflow/core/util/managed_stack_trace.h"
22 
23 namespace mlir {
24 
StatusScopedDiagnosticHandler(MLIRContext * context,bool propagate,bool filter_stack)25 StatusScopedDiagnosticHandler::StatusScopedDiagnosticHandler(
26     MLIRContext* context, bool propagate, bool filter_stack)
27     : SourceMgrDiagnosticHandler(source_mgr_, context, diag_stream_),
28       diag_stream_(diag_str_),
29       propagate_(propagate) {
30   if (filter_stack) {
31     this->shouldShowLocFn = [](Location loc) -> bool {
32       // For a Location to be surfaced in the stack, it must evaluate to true.
33       // For any Location that is a FileLineColLoc:
34       if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
35         return !tensorflow::IsInternalFrameForFilename(
36             fileLoc.getFilename().str());
37       } else {
38         // If this is a non-FileLineColLoc, go ahead and include it.
39         return true;
40       }
41     };
42   }
43 
44   setHandler([this](Diagnostic& diag) { return this->handler(&diag); });
45 }
46 
~StatusScopedDiagnosticHandler()47 StatusScopedDiagnosticHandler::~StatusScopedDiagnosticHandler() {
48   // Verify errors were consumed and re-register old handler.
49   bool all_errors_produced_were_consumed = ok();
50   DCHECK(all_errors_produced_were_consumed) << "Error status not consumed:\n"
51                                             << diag_str_;
52 }
53 
ok() const54 bool StatusScopedDiagnosticHandler::ok() const { return diag_str_.empty(); }
55 
ConsumeStatus()56 Status StatusScopedDiagnosticHandler::ConsumeStatus() {
57   if (ok()) return ::tensorflow::OkStatus();
58 
59   // TODO(jpienaar) This should be combining status with one previously built
60   // up.
61   Status s = tensorflow::errors::Unknown(diag_str_);
62   diag_str_.clear();
63   return s;
64 }
65 
Combine(Status status)66 Status StatusScopedDiagnosticHandler::Combine(Status status) {
67   if (status.ok()) return ConsumeStatus();
68 
69   // status is not-OK here, so if there was no diagnostics reported
70   // additionally then return this error.
71   if (ok()) return status;
72 
73   // Append the diagnostics reported to the status. This repeats the behavior of
74   // TensorFlow's AppendToMessage without the additional formatting inserted
75   // there.
76   status = ::tensorflow::Status(
77       status.code(), absl::StrCat(status.error_message(), diag_str_));
78   diag_str_.clear();
79   return status;
80 }
81 
handler(Diagnostic * diag)82 LogicalResult StatusScopedDiagnosticHandler::handler(Diagnostic* diag) {
83   size_t current_diag_str_size_ = diag_str_.size();
84 
85   // Emit the diagnostic and flush the stream.
86   emitDiagnostic(*diag);
87   diag_stream_.flush();
88 
89   // Emit non-errors to VLOG instead of the internal status.
90   if (diag->getSeverity() != DiagnosticSeverity::Error) {
91     VLOG(1) << diag_str_.substr(current_diag_str_size_);
92     diag_str_.resize(current_diag_str_size_);
93   }
94 
95   // Return failure to signal propagation if necessary.
96   return failure(propagate_);
97 }
98 
99 }  // namespace mlir
100